diff --git a/src/main/java/com/google/devtools/build/lib/bazel/rules/python/BazelPyRuleClasses.java b/src/main/java/com/google/devtools/build/lib/bazel/rules/python/BazelPyRuleClasses.java index 19da7af72903d6..0d4a0980a9f3da 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/rules/python/BazelPyRuleClasses.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/rules/python/BazelPyRuleClasses.java @@ -32,8 +32,10 @@ import com.google.devtools.build.lib.packages.Attribute.LabelLateBoundDefault; import com.google.devtools.build.lib.packages.RuleClass; import com.google.devtools.build.lib.packages.RuleClass.Builder.RuleClassType; +import com.google.devtools.build.lib.packages.SkylarkProviderIdentifier; import com.google.devtools.build.lib.packages.TriState; import com.google.devtools.build.lib.rules.python.PyCommon; +import com.google.devtools.build.lib.rules.python.PyInfo; import com.google.devtools.build.lib.rules.python.PyRuleClasses; import com.google.devtools.build.lib.rules.python.PyStructUtils; import com.google.devtools.build.lib.rules.python.PythonVersion; @@ -70,6 +72,14 @@ public RuleClass build(RuleClass.Builder builder, RuleDefinitionEnvironment env) .override( builder .copy("deps") + .mandatoryProvidersList( + ImmutableList.of( + // Legacy provider. + // TODO(#7010): Remove this legacy set. + ImmutableList.of( + SkylarkProviderIdentifier.forLegacy(PyStructUtils.PROVIDER_NAME)), + // Modern provider. + ImmutableList.of(PyInfo.PROVIDER.id()))) .legacyMandatoryProviders(PyStructUtils.PROVIDER_NAME) .allowedFileTypes()) /* diff --git a/src/main/java/com/google/devtools/build/lib/rules/python/PyCommon.java b/src/main/java/com/google/devtools/build/lib/rules/python/PyCommon.java index 66729d48ee7647..15f1f9635cde68 100644 --- a/src/main/java/com/google/devtools/build/lib/rules/python/PyCommon.java +++ b/src/main/java/com/google/devtools/build/lib/rules/python/PyCommon.java @@ -562,6 +562,14 @@ public void addCommonTransitiveInfoProviders( NestedSet filesToBuild, NestedSet imports) { + PyProviderUtils.builder() + .setTransitiveSources(transitivePythonSources) + .setUsesSharedLibraries(usesSharedLibraries) + .setImports(imports) + .setHasPy2OnlySources(hasPy2OnlySources) + .setHasPy3OnlySources(hasPy3OnlySources) + .buildAndAddToTarget(builder); + builder .addNativeDeclaredProvider( InstrumentedFilesCollector.collect( @@ -570,15 +578,6 @@ public void addCommonTransitiveInfoProviders( METADATA_COLLECTOR, filesToBuild, /* reportedToActualSources= */ NestedSetBuilder.create(Order.STABLE_ORDER))) - .addSkylarkTransitiveInfo( - PyStructUtils.PROVIDER_NAME, - PyStructUtils.builder() - .setTransitiveSources(transitivePythonSources) - .setUsesSharedLibraries(usesSharedLibraries) - .setImports(imports) - .setHasPy2OnlySources(hasPy2OnlySources) - .setHasPy3OnlySources(hasPy3OnlySources) - .build()) // Python targets are not really compilable. The best we can do is make sure that all // generated source files are ready. .addOutputGroup(OutputGroupInfo.FILES_TO_COMPILE, transitivePythonSources) diff --git a/src/main/java/com/google/devtools/build/lib/rules/python/PyProviderUtils.java b/src/main/java/com/google/devtools/build/lib/rules/python/PyProviderUtils.java index 9e0dbe20e9c55e..94dede4d38cf69 100644 --- a/src/main/java/com/google/devtools/build/lib/rules/python/PyProviderUtils.java +++ b/src/main/java/com/google/devtools/build/lib/rules/python/PyProviderUtils.java @@ -16,6 +16,7 @@ import com.google.devtools.build.lib.actions.Artifact; import com.google.devtools.build.lib.analysis.FileProvider; +import com.google.devtools.build.lib.analysis.RuleConfiguredTargetBuilder; import com.google.devtools.build.lib.analysis.TransitiveInfoCollection; import com.google.devtools.build.lib.collect.nestedset.NestedSet; import com.google.devtools.build.lib.collect.nestedset.NestedSetBuilder; @@ -36,17 +37,30 @@ public class PyProviderUtils { // Disable construction. private PyProviderUtils() {} - /** Returns whether a given target has the py provider. */ - public static boolean hasProvider(TransitiveInfoCollection target) { + /** Returns whether a given target has the {@code PyInfo} provider. */ + public static boolean hasModernProvider(TransitiveInfoCollection target) { + return target.get(PyInfo.PROVIDER) != null; + } + + /** + * Returns the {@code PyInfo} provider from the given target info, or null if the provider is not + * present. + */ + public static PyInfo getModernProvider(TransitiveInfoCollection target) { + return target.get(PyInfo.PROVIDER); + } + + /** Returns whether a given target has the legacy "py" provider. */ + public static boolean hasLegacyProvider(TransitiveInfoCollection target) { return target.get(PyStructUtils.PROVIDER_NAME) != null; } /** - * Returns the struct representing the py provider, from the given target info. + * Returns the struct representing the legacy "py" provider, from the given target info. * * @throws EvalException if the provider does not exist or has the wrong type. */ - public static StructImpl getProvider(TransitiveInfoCollection target) throws EvalException { + public static StructImpl getLegacyProvider(TransitiveInfoCollection target) throws EvalException { Object info = target.get(PyStructUtils.PROVIDER_NAME); if (info == null) { throw new EvalException(/*location=*/ null, "Target does not have 'py' provider"); @@ -71,8 +85,10 @@ public static StructImpl getProvider(TransitiveInfoCollection target) throws Eva // the relevant rules. public static NestedSet getTransitiveSources(TransitiveInfoCollection target) throws EvalException { - if (hasProvider(target)) { - return PyStructUtils.getTransitiveSources(getProvider(target)); + if (hasModernProvider(target)) { + return getModernProvider(target).getTransitiveSources().getSet(Artifact.class); + } else if (hasLegacyProvider(target)) { + return PyStructUtils.getTransitiveSources(getLegacyProvider(target)); } else { NestedSet files = target.getProvider(FileProvider.class).getFilesToBuild(); return NestedSetBuilder.compileOrder() @@ -92,8 +108,10 @@ public static NestedSet getTransitiveSources(TransitiveInfoCollection */ public static boolean getUsesSharedLibraries(TransitiveInfoCollection target) throws EvalException { - if (hasProvider(target)) { - return PyStructUtils.getUsesSharedLibraries(getProvider(target)); + if (hasModernProvider(target)) { + return getModernProvider(target).getUsesSharedLibraries(); + } else if (hasLegacyProvider(target)) { + return PyStructUtils.getUsesSharedLibraries(getLegacyProvider(target)); } else { NestedSet files = target.getProvider(FileProvider.class).getFilesToBuild(); return FileType.contains(files, CppFileTypes.SHARED_LIBRARY); @@ -127,8 +145,10 @@ public static NestedSet getImports(TransitiveInfoCollection target) thro * to false. */ public static boolean getHasPy2OnlySources(TransitiveInfoCollection target) throws EvalException { - if (hasProvider(target)) { - return PyStructUtils.getHasPy2OnlySources(getProvider(target)); + if (hasModernProvider(target)) { + return getModernProvider(target).getHasPy2OnlySources(); + } else if (hasLegacyProvider(target)) { + return PyStructUtils.getHasPy2OnlySources(getLegacyProvider(target)); } else { return false; } @@ -141,10 +161,62 @@ public static boolean getHasPy2OnlySources(TransitiveInfoCollection target) thro * to false. */ public static boolean getHasPy3OnlySources(TransitiveInfoCollection target) throws EvalException { - if (hasProvider(target)) { - return PyStructUtils.getHasPy3OnlySources(getProvider(target)); + if (hasModernProvider(target)) { + return getModernProvider(target).getHasPy3OnlySources(); + } else if (hasLegacyProvider(target)) { + return PyStructUtils.getHasPy3OnlySources(getLegacyProvider(target)); } else { return false; } } + + public static Builder builder() { + return new Builder(); + } + + /** A builder to add both the legacy and modern providers to a configured target. */ + public static class Builder { + private final PyInfo.Builder modernBuilder = PyInfo.builder(); + private final PyStructUtils.Builder legacyBuilder = PyStructUtils.builder(); + + // Use the static builder() method instead. + private Builder() {} + + public Builder setTransitiveSources(NestedSet transitiveSources) { + modernBuilder.setTransitiveSources(transitiveSources); + legacyBuilder.setTransitiveSources(transitiveSources); + return this; + } + + public Builder setUsesSharedLibraries(boolean usesSharedLibraries) { + modernBuilder.setUsesSharedLibraries(usesSharedLibraries); + legacyBuilder.setUsesSharedLibraries(usesSharedLibraries); + return this; + } + + public Builder setImports(NestedSet imports) { + modernBuilder.setImports(imports); + legacyBuilder.setImports(imports); + return this; + } + + public Builder setHasPy2OnlySources(boolean hasPy2OnlySources) { + modernBuilder.setHasPy2OnlySources(hasPy2OnlySources); + legacyBuilder.setHasPy2OnlySources(hasPy2OnlySources); + return this; + } + + public Builder setHasPy3OnlySources(boolean hasPy3OnlySources) { + modernBuilder.setHasPy3OnlySources(hasPy3OnlySources); + legacyBuilder.setHasPy3OnlySources(hasPy3OnlySources); + return this; + } + + public RuleConfiguredTargetBuilder buildAndAddToTarget( + RuleConfiguredTargetBuilder targetBuilder) { + targetBuilder.addSkylarkTransitiveInfo(PyStructUtils.PROVIDER_NAME, legacyBuilder.build()); + targetBuilder.addNativeDeclaredProvider(modernBuilder.build()); + return targetBuilder; + } + } } diff --git a/src/test/java/com/google/devtools/build/lib/rules/python/BUILD b/src/test/java/com/google/devtools/build/lib/rules/python/BUILD index 422e6b271ded41..e7aafc5a8d606b 100644 --- a/src/test/java/com/google/devtools/build/lib/rules/python/BUILD +++ b/src/test/java/com/google/devtools/build/lib/rules/python/BUILD @@ -225,8 +225,10 @@ java_test( name = "PythonStarlarkApiTest", srcs = ["PythonStarlarkApiTest.java"], deps = [ + "//src/main/java/com/google/devtools/build/lib:build-base", "//src/main/java/com/google/devtools/build/lib:packages-internal", "//src/main/java/com/google/devtools/build/lib:python-rules", + "//src/main/java/com/google/devtools/build/lib/actions", "//src/test/java/com/google/devtools/build/lib:analysis_testutil", "//third_party:junit4", "//third_party:truth", diff --git a/src/test/java/com/google/devtools/build/lib/rules/python/PyBaseConfiguredTargetTestBase.java b/src/test/java/com/google/devtools/build/lib/rules/python/PyBaseConfiguredTargetTestBase.java index c5579d41b4cdbd..a1a3794049732d 100644 --- a/src/test/java/com/google/devtools/build/lib/rules/python/PyBaseConfiguredTargetTestBase.java +++ b/src/test/java/com/google/devtools/build/lib/rules/python/PyBaseConfiguredTargetTestBase.java @@ -57,7 +57,8 @@ public void badSrcsVersionValue() throws Exception { @Test public void goodSrcsVersionValue() throws Exception { - scratch.file("pkg/BUILD", + scratch.file( + "pkg/BUILD", ruleName + "(", " name = 'foo',", " srcs_version = 'PY2',", @@ -89,7 +90,8 @@ public void srcsVersionClashesWithForcePythonFlagUnderOldSemantics() throws Exce @Test public void versionIs2IfUnspecified() throws Exception { ensureDefaultIsPY2(); - scratch.file("pkg/BUILD", + scratch.file( + "pkg/BUILD", // ruleName + "(", " name = 'foo',", " srcs = ['foo.py'])"); @@ -105,7 +107,8 @@ public void versionIs3IfForcedByFlagUnderOldSemantics() throws Exception { // test. ensureDefaultIsPY2(); useConfiguration("--experimental_allow_python_version_transitions=false", "--force_python=PY3"); - scratch.file("pkg/BUILD", + scratch.file( + "pkg/BUILD", // ruleName + "(", " name = 'foo',", " srcs = ['foo.py'])"); @@ -125,7 +128,8 @@ public void packageNameCannotHaveHyphen() throws Exception { @Test public void srcsPackageNameCannotHaveHyphen() throws Exception { - scratch.file("pkg-hyphenated/BUILD", + scratch.file( + "pkg-hyphenated/BUILD", // "exports_files(['bar.py'])"); checkError("otherpkg", "foo", // error: @@ -135,4 +139,92 @@ public void srcsPackageNameCannotHaveHyphen() throws Exception { " name = 'foo',", " srcs = ['foo.py', '//pkg-hyphenated:bar.py'])"); } + + @Test + public void producesBothModernAndLegacyProviders() throws Exception { + scratch.file( + "pkg/BUILD", // + ruleName + "(", + " name = 'foo',", + " srcs = ['foo.py'])"); + ConfiguredTarget target = getConfiguredTarget("//pkg:foo"); + assertThat(target.get(PyInfo.PROVIDER)).isNotNull(); + assertThat(target.get(PyStructUtils.PROVIDER_NAME)).isNotNull(); + } + + @Test + public void consumesLegacyProvider() throws Exception { + scratch.file( + "pkg/rules.bzl", + "def _myrule_impl(ctx):", + " return struct(py=struct(transitive_sources=depset([])))", + "myrule = rule(", + " implementation = _myrule_impl,", + ")"); + scratch.file( + "pkg/BUILD", + "load(':rules.bzl', 'myrule')", + "myrule(", + " name = 'dep',", + ")", + ruleName + "(", + " name = 'foo',", + " srcs = ['foo.py'],", + " deps = [':dep'],", + ")"); + ConfiguredTarget target = getConfiguredTarget("//pkg:foo"); + assertThat(target).isNotNull(); + assertNoEvents(); + } + + @Test + public void consumesModernProvider() throws Exception { + scratch.file( + "pkg/rules.bzl", + "def _myrule_impl(ctx):", + " return [PyInfo(transitive_sources=depset([]))]", + "myrule = rule(", + " implementation = _myrule_impl,", + ")"); + scratch.file( + "pkg/BUILD", + "load(':rules.bzl', 'myrule')", + "myrule(", + " name = 'dep',", + ")", + ruleName + "(", + " name = 'foo',", + " srcs = ['foo.py'],", + " deps = [':dep'],", + ")"); + ConfiguredTarget target = getConfiguredTarget("//pkg:foo"); + assertThat(target).isNotNull(); + assertNoEvents(); + } + + @Test + public void requiresProvider() throws Exception { + scratch.file( + "pkg/rules.bzl", + "def _myrule_impl(ctx):", + " return []", + "myrule = rule(", + " implementation = _myrule_impl,", + ")"); + checkError( + "pkg", + "foo", + // error: + "'//pkg:dep' does not have mandatory providers", + // build file: + "load(':rules.bzl', 'myrule')", + "myrule(", + " name = 'dep',", + ")", + ruleName + "(", + " name = 'foo',", + " srcs = ['foo.py'],", + " deps = [':dep'],", + ")"); + } } diff --git a/src/test/java/com/google/devtools/build/lib/rules/python/PyExecutableConfiguredTargetTestBase.java b/src/test/java/com/google/devtools/build/lib/rules/python/PyExecutableConfiguredTargetTestBase.java index f548df9ff52c45..338a6fc13a612d 100644 --- a/src/test/java/com/google/devtools/build/lib/rules/python/PyExecutableConfiguredTargetTestBase.java +++ b/src/test/java/com/google/devtools/build/lib/rules/python/PyExecutableConfiguredTargetTestBase.java @@ -107,7 +107,7 @@ protected void assertPythonVersionIs_UnderNewConfigs( } } - private String join(String... lines) { + private static String join(String... lines) { return String.join("\n", lines); } diff --git a/src/test/java/com/google/devtools/build/lib/rules/python/PyProviderUtilsTest.java b/src/test/java/com/google/devtools/build/lib/rules/python/PyProviderUtilsTest.java index 5edc536fe8c798..9b4747ed0a0272 100644 --- a/src/test/java/com/google/devtools/build/lib/rules/python/PyProviderUtilsTest.java +++ b/src/test/java/com/google/devtools/build/lib/rules/python/PyProviderUtilsTest.java @@ -29,6 +29,10 @@ @RunWith(JUnit4.class) public class PyProviderUtilsTest extends BuildViewTestCase { + private static String join(String... lines) { + return String.join("\n", lines); + } + private static void assertThrowsEvalExceptionContaining( ThrowingRunnable runnable, String message) { assertThat(assertThrows(EvalException.class, runnable)).hasMessageThat().contains(message); @@ -37,18 +41,22 @@ private static void assertThrowsEvalExceptionContaining( /** * Creates {@code //pkg:target} as an instance of a trivial rule having the given implementation * function body. + * + *

The body is formed by joining the input strings with newlines, then inserting 4 spaces of + * indentation before the first line and each newline. This allows a single string argument to + * contain multiple pre-joined lines and still be indented correctly. */ - private void declareTargetWithImplementation(String... lines) throws Exception { - String body; - if (lines.length == 0) { - body = " pass"; + private void declareTargetWithImplementation(String... input) throws Exception { + String indentedBody; + if (input.length == 0) { + indentedBody = " pass"; } else { - body = " " + String.join("\n ", lines); + indentedBody = " " + String.join("\n", input).replace("\n", "\n "); } scratch.file( "pkg/rules.bzl", "def _myrule_impl(ctx):", - body, + indentedBody, "myrule = rule(", " implementation = _myrule_impl,", ")"); @@ -68,59 +76,120 @@ private ConfiguredTarget getTarget() throws Exception { } @Test - public void getAndHasProvider_Present() throws Exception { + public void getAndHasModernProvider_Present() throws Exception { + declareTargetWithImplementation( // + "return [PyInfo(transitive_sources=depset([]))]"); + assertThat(PyProviderUtils.hasModernProvider(getTarget())).isTrue(); + assertThat(PyProviderUtils.getModernProvider(getTarget())).isNotNull(); + } + + @Test + public void getAndHasModernProvider_Absent() throws Exception { + declareTargetWithImplementation( // + "return []"); + assertThat(PyProviderUtils.hasModernProvider(getTarget())).isFalse(); + assertThat(PyProviderUtils.getModernProvider(getTarget())).isNull(); + } + + @Test + public void getAndHasLegacyProvider_Present() throws Exception { declareTargetWithImplementation( // "return struct(py=struct())"); // Accessor doesn't actually check fields - assertThat(PyProviderUtils.hasProvider(getTarget())).isTrue(); - assertThat(PyProviderUtils.getProvider(getTarget())).isNotNull(); + assertThat(PyProviderUtils.hasLegacyProvider(getTarget())).isTrue(); + assertThat(PyProviderUtils.getLegacyProvider(getTarget())).isNotNull(); } @Test - public void getAndHasProvider_Absent() throws Exception { + public void getAndHasLegacyProvider_Absent() throws Exception { declareTargetWithImplementation( // "return []"); - assertThat(PyProviderUtils.hasProvider(getTarget())).isFalse(); + assertThat(PyProviderUtils.hasLegacyProvider(getTarget())).isFalse(); assertThrowsEvalExceptionContaining( - () -> PyProviderUtils.getProvider(getTarget()), "Target does not have 'py' provider"); + () -> PyProviderUtils.getLegacyProvider(getTarget()), "Target does not have 'py' provider"); } @Test - public void getProvider_NotAStruct() throws Exception { + public void getLegacyProvider_NotAStruct() throws Exception { declareTargetWithImplementation( // "return struct(py=123)"); assertThrowsEvalExceptionContaining( - () -> PyProviderUtils.getProvider(getTarget()), "'py' provider should be a struct"); + () -> PyProviderUtils.getLegacyProvider(getTarget()), "'py' provider should be a struct"); } + private static final String TRANSITIVE_SOURCES_SETUP_CODE = + join( + "afile = ctx.actions.declare_file('a.py')", + "ctx.actions.write(output=afile, content='a')", + "bfile = ctx.actions.declare_file('b.py')", + "ctx.actions.write(output=bfile, content='b')", + "modern_info = PyInfo(transitive_sources = depset(direct=[afile]))", + "legacy_info = struct(transitive_sources = depset(direct=[bfile]))"); + @Test - public void getTransitiveSources_Provider() throws Exception { - declareTargetWithImplementation( - "afile = ctx.actions.declare_file('a.py')", - "ctx.actions.write(output=afile, content='a')", - "info = struct(transitive_sources = depset(direct=[afile]))", - "return struct(py=info)"); + public void getTransitiveSources_ModernProvider() throws Exception { + declareTargetWithImplementation( // + TRANSITIVE_SOURCES_SETUP_CODE, // + "return [modern_info]"); + assertThat(PyProviderUtils.getTransitiveSources(getTarget())) + .containsExactly(getBinArtifact("a.py", getTarget())); + } + + @Test + public void getTransitiveSources_LegacyProvider() throws Exception { + declareTargetWithImplementation( // + TRANSITIVE_SOURCES_SETUP_CODE, // + "return struct(py=legacy_info)"); + assertThat(PyProviderUtils.getTransitiveSources(getTarget())) + .containsExactly(getBinArtifact("b.py", getTarget())); + } + + @Test + public void getTransitiveSources_BothProviders() throws Exception { + declareTargetWithImplementation( // + TRANSITIVE_SOURCES_SETUP_CODE, // + "return struct(py=legacy_info, providers=[modern_info])"); assertThat(PyProviderUtils.getTransitiveSources(getTarget())) .containsExactly(getBinArtifact("a.py", getTarget())); } @Test public void getTransitiveSources_NoProvider() throws Exception { - declareTargetWithImplementation( - "afile = ctx.actions.declare_file('a.py')", - "ctx.actions.write(output=afile, content='a')", + declareTargetWithImplementation( // + TRANSITIVE_SOURCES_SETUP_CODE, // "return [DefaultInfo(files=depset(direct=[afile]))]"); assertThat(PyProviderUtils.getTransitiveSources(getTarget())) .containsExactly(getBinArtifact("a.py", getTarget())); } + private static final String USES_SHARED_LIBRARIES_SETUP_CODE = + join( + "modern_info = PyInfo(transitive_sources = depset([]), uses_shared_libraries=False)", + "legacy_info = struct(transitive_sources = depset([]), uses_shared_libraries=True)"); + @Test - public void getUsesSharedLibraries_Provider() throws Exception { - declareTargetWithImplementation( - "info = struct(transitive_sources = depset([]), uses_shared_libraries=True)", - "return struct(py=info)"); + public void getUsesSharedLibraries_ModernProvider() throws Exception { + declareTargetWithImplementation( // + USES_SHARED_LIBRARIES_SETUP_CODE, // + "return [modern_info]"); + assertThat(PyProviderUtils.getUsesSharedLibraries(getTarget())).isFalse(); + } + + @Test + public void getUsesSharedLibraries_LegacyProvider() throws Exception { + declareTargetWithImplementation( // + USES_SHARED_LIBRARIES_SETUP_CODE, // + "return struct(py=legacy_info)"); assertThat(PyProviderUtils.getUsesSharedLibraries(getTarget())).isTrue(); } + @Test + public void getUsesSharedLibraries_BothProviders() throws Exception { + declareTargetWithImplementation( // + USES_SHARED_LIBRARIES_SETUP_CODE, // + "return struct(py=legacy_info, providers=[modern_info])"); + assertThat(PyProviderUtils.getUsesSharedLibraries(getTarget())).isFalse(); + } + @Test public void getUsesSharedLibraries_NoProvider() throws Exception { declareTargetWithImplementation( @@ -133,14 +202,35 @@ public void getUsesSharedLibraries_NoProvider() throws Exception { // TODO(#7054): Add tests for getImports once we actually read the Python provider instead of the // specialized PythonImportsProvider. + private static final String HAS_PY2_ONLY_SOURCES_SETUP_CODE = + join( + "modern_info = PyInfo(transitive_sources = depset([]), has_py2_only_sources=False)", + "legacy_info = struct(transitive_sources = depset([]), has_py2_only_sources=True)"); + @Test - public void getHasPy2OnlySources_Provider() throws Exception { - declareTargetWithImplementation( - "info = struct(transitive_sources = depset([]), has_py2_only_sources=True)", - "return struct(py=info)"); + public void getHasPy2OnlySources_ModernProvider() throws Exception { + declareTargetWithImplementation( // + HAS_PY2_ONLY_SOURCES_SETUP_CODE, // + "return [modern_info]"); + assertThat(PyProviderUtils.getHasPy2OnlySources(getTarget())).isFalse(); + } + + @Test + public void getHasPy2OnlySources_LegacyProvider() throws Exception { + declareTargetWithImplementation( // + HAS_PY2_ONLY_SOURCES_SETUP_CODE, // + "return struct(py=legacy_info)"); assertThat(PyProviderUtils.getHasPy2OnlySources(getTarget())).isTrue(); } + @Test + public void getHasPy2OnlySources_BothProviders() throws Exception { + declareTargetWithImplementation( // + HAS_PY2_ONLY_SOURCES_SETUP_CODE, // + "return struct(py=legacy_info, providers=[modern_info])"); + assertThat(PyProviderUtils.getHasPy2OnlySources(getTarget())).isFalse(); + } + @Test public void getHasPy2OnlySources_NoProvider() throws Exception { declareTargetWithImplementation( // @@ -148,14 +238,35 @@ public void getHasPy2OnlySources_NoProvider() throws Exception { assertThat(PyProviderUtils.getHasPy2OnlySources(getTarget())).isFalse(); } + private static final String HAS_PY3_ONLY_SOURCES_SETUP_CODE = + join( + "modern_info = PyInfo(transitive_sources = depset([]), has_py3_only_sources=False)", + "legacy_info = struct(transitive_sources = depset([]), has_py3_only_sources=True)"); + @Test - public void getHasPy3OnlySources_Provider() throws Exception { - declareTargetWithImplementation( - "info = struct(transitive_sources = depset([]), has_py3_only_sources=True)", - "return struct(py=info)"); + public void getHasPy3OnlySources_ModernProvider() throws Exception { + declareTargetWithImplementation( // + HAS_PY3_ONLY_SOURCES_SETUP_CODE, // + "return [modern_info]"); + assertThat(PyProviderUtils.getHasPy3OnlySources(getTarget())).isFalse(); + } + + @Test + public void getHasPy3OnlySources_LegacyProvider() throws Exception { + declareTargetWithImplementation( // + HAS_PY3_ONLY_SOURCES_SETUP_CODE, // + "return struct(py=legacy_info)"); assertThat(PyProviderUtils.getHasPy3OnlySources(getTarget())).isTrue(); } + @Test + public void getHasPy3OnlySources_BothProviders() throws Exception { + declareTargetWithImplementation( // + HAS_PY3_ONLY_SOURCES_SETUP_CODE, // + "return struct(py=legacy_info, providers=[modern_info])"); + assertThat(PyProviderUtils.getHasPy3OnlySources(getTarget())).isFalse(); + } + @Test public void getHasPy3OnlySources_NoProvider() throws Exception { declareTargetWithImplementation( // diff --git a/src/test/java/com/google/devtools/build/lib/rules/python/PythonStarlarkApiTest.java b/src/test/java/com/google/devtools/build/lib/rules/python/PythonStarlarkApiTest.java index 71ed741fb2b646..db8c29cec8b42e 100644 --- a/src/test/java/com/google/devtools/build/lib/rules/python/PythonStarlarkApiTest.java +++ b/src/test/java/com/google/devtools/build/lib/rules/python/PythonStarlarkApiTest.java @@ -16,6 +16,8 @@ import static com.google.common.truth.Truth.assertThat; +import com.google.devtools.build.lib.actions.Artifact; +import com.google.devtools.build.lib.analysis.ConfiguredTarget; import com.google.devtools.build.lib.analysis.util.BuildViewTestCase; import com.google.devtools.build.lib.packages.StructImpl; import org.junit.Test; @@ -48,12 +50,17 @@ private void defineUserlibRule() throws Exception { " has_py3_only_sources = \\", " any([py.has_py3_only_sources for py in dep_infos]) or \\", " ctx.attr.has_py3_only_sources", - " info = struct(", + " legacy_info = struct(", " transitive_sources = transitive_sources,", " uses_shared_libraries = uses_shared_libraries,", " has_py2_only_sources = has_py2_only_sources,", " has_py3_only_sources = has_py3_only_sources)", - " return struct(py=info)", + " modern_info = PyInfo(", + " transitive_sources = transitive_sources,", + " uses_shared_libraries = uses_shared_libraries,", + " has_py2_only_sources = has_py2_only_sources,", + " has_py3_only_sources = has_py3_only_sources)", + " return struct(py=legacy_info, providers=[modern_info])", "", "userlib = rule(", " implementation = _userlib_impl,", @@ -92,14 +99,26 @@ public void librarySandwich() throws Exception { " srcs = ['upperuserlib.py'],", " deps = [':pylib'],", ")"); - StructImpl info = PyProviderUtils.getProvider(getConfiguredTarget("//pkg:upperuserlib")); - assertThat(PyStructUtils.getTransitiveSources(info)) + ConfiguredTarget target = getConfiguredTarget("//pkg:upperuserlib"); + StructImpl legacyInfo = PyProviderUtils.getLegacyProvider(target); + PyInfo modernInfo = PyProviderUtils.getModernProvider(target); + + assertThat(PyStructUtils.getTransitiveSources(legacyInfo)) + .containsExactly( + getSourceArtifact("pkg/loweruserlib.py"), + getSourceArtifact("pkg/pylib.py"), + getSourceArtifact("pkg/upperuserlib.py")); + assertThat(PyStructUtils.getUsesSharedLibraries(legacyInfo)).isTrue(); + assertThat(PyStructUtils.getHasPy2OnlySources(legacyInfo)).isTrue(); + assertThat(PyStructUtils.getHasPy3OnlySources(legacyInfo)).isTrue(); + + assertThat(modernInfo.getTransitiveSources().getSet(Artifact.class)) .containsExactly( getSourceArtifact("pkg/loweruserlib.py"), getSourceArtifact("pkg/pylib.py"), getSourceArtifact("pkg/upperuserlib.py")); - assertThat(PyStructUtils.getUsesSharedLibraries(info)).isTrue(); - assertThat(PyStructUtils.getHasPy2OnlySources(info)).isTrue(); - assertThat(PyStructUtils.getHasPy3OnlySources(info)).isTrue(); + assertThat(modernInfo.getUsesSharedLibraries()).isTrue(); + assertThat(modernInfo.getHasPy2OnlySources()).isTrue(); + assertThat(modernInfo.getHasPy3OnlySources()).isTrue(); } } diff --git a/src/test/java/com/google/devtools/build/lib/skylark/SkylarkRuleContextTest.java b/src/test/java/com/google/devtools/build/lib/skylark/SkylarkRuleContextTest.java index c38ad4b9e229f4..b79dd4c1887b26 100644 --- a/src/test/java/com/google/devtools/build/lib/skylark/SkylarkRuleContextTest.java +++ b/src/test/java/com/google/devtools/build/lib/skylark/SkylarkRuleContextTest.java @@ -492,7 +492,7 @@ public void shouldGetPrerequisites() throws Exception { TransitiveInfoCollection tic1 = (TransitiveInfoCollection) ((SkylarkList) result).get(0); assertThat(JavaInfo.getProvider(JavaSourceJarsProvider.class, tic1)).isNotNull(); // Check an unimplemented provider too - assertThat(PyProviderUtils.hasProvider(tic1)).isFalse(); + assertThat(PyProviderUtils.hasLegacyProvider(tic1)).isFalse(); } @Test