diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl index 056e4ee61951ae..d0d548818b91af 100644 --- a/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl +++ b/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl @@ -32,6 +32,7 @@ ProtoLangToolchainInfo = provider( mnemonic = "(str) Mnemonic to set on the proto compiler action.", allowlist_different_package = """(Target) Allowlist to create lang_proto_library in a different package than proto_library""", + toolchain_type = """(Label) Toolchain type that was used to obtain this info""", ), ) @@ -238,7 +239,7 @@ def _compile( use_default_shell_env = True, resource_set = resource_set, exec_group = experimental_exec_group, - toolchain = None, + toolchain = getattr(proto_lang_toolchain_info, "toolchain_type", None), ) _BAZEL_TOOLS_PREFIX = "external/bazel_tools/" @@ -375,7 +376,6 @@ toolchains = struct( use_toolchain = _use_toolchain, find_toolchain = _find_toolchain, if_legacy_toolchain = _if_legacy_toolchain, - INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION = _builtins.toplevel.proto_common.incompatible_enable_proto_toolchain_resolution(), ) proto_common_do_not_use = struct( @@ -386,4 +386,6 @@ proto_common_do_not_use = struct( experimental_filter_sources = _experimental_filter_sources, get_import_path = _get_import_path, ProtoLangToolchainInfo = ProtoLangToolchainInfo, + INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION = _builtins.toplevel.proto_common.incompatible_enable_proto_toolchain_resolution(), + INCOMPATIBLE_PASS_TOOLCHAIN_TYPE = True, ) diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl index 894c400f53d0f7..cd984e35b8ae86 100644 --- a/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl +++ b/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl @@ -14,7 +14,7 @@ """A Starlark implementation of the proto_lang_toolchain rule.""" -load(":common/proto/proto_common.bzl", "ProtoLangToolchainInfo", "toolchains") +load(":common/proto/proto_common.bzl", "ProtoLangToolchainInfo", "toolchains", proto_common = "proto_common_do_not_use") load(":common/proto/proto_info.bzl", "ProtoInfo") load(":common/proto/proto_semantics.bzl", "semantics") @@ -32,7 +32,7 @@ def _rule_impl(ctx): if ctx.attr.plugin != None: plugin = ctx.attr.plugin[DefaultInfo].files_to_run - if toolchains.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION: + if proto_common.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION: proto_compiler = ctx.toolchains[semantics.PROTO_TOOLCHAIN].proto.proto_compiler protoc_opts = ctx.toolchains[semantics.PROTO_TOOLCHAIN].proto.protoc_opts else: @@ -51,6 +51,7 @@ def _rule_impl(ctx): progress_message = ctx.attr.progress_message, mnemonic = ctx.attr.mnemonic, allowlist_different_package = ctx.attr.allowlist_different_package, + toolchain_type = ctx.attr.toolchain_type.label if ctx.attr.toolchain_type else None, ) return [ DefaultInfo(files = depset(), runfiles = ctx.runfiles()), @@ -81,7 +82,8 @@ proto_lang_toolchain = rule( cfg = "exec", providers = [PackageSpecificationInfo], ), - } | ({} if toolchains.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION else { + "toolchain_type": attr.label(), + } | ({} if proto_common.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION else { "_proto_compiler": attr.label( cfg = "exec", executable = True, diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl index f14fdee4c36db4..bb6ac60287947e 100644 --- a/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl +++ b/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl @@ -208,7 +208,7 @@ def _write_descriptor_set(ctx, proto_info, deps, exports, descriptor_set): map_each = proto_common.get_import_path, join_with = ":", ) - if toolchains.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION: + if proto_common.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION: toolchain = ctx.toolchains[semantics.PROTO_TOOLCHAIN] if not toolchain: fail("Protocol compiler toolchain could not be resolved.") diff --git a/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java b/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java index 75d8d417ab7e35..fdf0af463d666f 100644 --- a/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java +++ b/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java @@ -263,6 +263,7 @@ public static void setupWorkspace(MockToolsConfig config) throws IOException { " progress_message = ctx.attr.progress_message,", " mnemonic = ctx.attr.mnemonic,", " allowlist_different_package = None,", + " toolchain_type = '//third_party/bazel_rules/rules_proto/proto:toolchain_type'", " ),", " ),", " ]", @@ -288,7 +289,7 @@ public static void setupWorkspace(MockToolsConfig config) throws IOException { "third_party/bazel_rules/rules_proto/proto/proto_lang_toolchain.bzl", "def proto_lang_toolchain(*, name, toolchain_type = None, exec_compatible_with = [],", " target_compatible_with = [], **attrs):", - " native.proto_lang_toolchain(name = name, **attrs)", + " native.proto_lang_toolchain(name = name, toolchain_type = toolchain_type, **attrs)", " if toolchain_type:", " native.toolchain(", " name = name + '_toolchain',",