diff --git a/src/python/pants/build_graph/build_configuration.py b/src/python/pants/build_graph/build_configuration.py index 9dbb3894663..d5364810d1c 100644 --- a/src/python/pants/build_graph/build_configuration.py +++ b/src/python/pants/build_graph/build_configuration.py @@ -6,7 +6,7 @@ import logging from builtins import object, str -from collections import namedtuple +from collections import OrderedDict, namedtuple from twitter.common.collections import OrderedSet @@ -38,6 +38,7 @@ def __init__(self): self._exposed_context_aware_object_factory_by_alias = {} self._optionables = OrderedSet() self._rules = OrderedSet() + self._union_rules = OrderedDict() def registered_aliases(self): """Return the registered aliases exposed in BUILD files. @@ -152,10 +153,13 @@ def register_rules(self, rules): raise TypeError('The rules must be an iterable, given {!r}'.format(rules)) # "Index" the rules to normalize them and expand their dependencies. - indexed_rules = RuleIndex.create(rules).normalized_rules() + normalized_rules = RuleIndex.create(rules).normalized_rules() + indexed_rules = normalized_rules.rules + union_rules = normalized_rules.union_rules # Store the rules and record their dependency Optionables. self._rules.update(indexed_rules) + self._union_rules.update(union_rules) dependency_optionables = {do for rule in indexed_rules for do in rule.dependency_optionables @@ -165,10 +169,17 @@ def register_rules(self, rules): def rules(self): """Returns the registered rules. - :rtype list + :rtype: list """ return list(self._rules) + def union_rules(self): + """Returns a mapping of registered union base types -> [a list of union member types]. + + :rtype: OrderedDict + """ + return self._union_rules + @memoized_method def _get_addressable_factory(self, target_type, alias): return TargetAddressable.factory(target_type=target_type, alias=alias) diff --git a/src/python/pants/engine/legacy/graph.py b/src/python/pants/engine/legacy/graph.py index 7bd0efbeaf8..ab8420713ad 100644 --- a/src/python/pants/engine/legacy/graph.py +++ b/src/python/pants/engine/legacy/graph.py @@ -5,7 +5,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals import logging -from builtins import str, zip +from builtins import object, str, zip from collections import defaultdict, deque from contextlib import contextmanager from os.path import dirname @@ -24,7 +24,8 @@ from pants.engine.addressable import BuildFileAddresses from pants.engine.fs import PathGlobs, Snapshot from pants.engine.legacy.address_mapper import LegacyAddressMapper -from pants.engine.legacy.structs import BundleAdaptor, BundlesField, SourcesField, TargetAdaptor +from pants.engine.legacy.structs import (BundleAdaptor, BundlesField, HydrateableField, + SourcesField, TargetAdaptor) from pants.engine.mapper import AddressMapper from pants.engine.objects import Collection from pants.engine.parser import SymbolTable, TargetAdaptorContainer @@ -506,9 +507,7 @@ def hydrate_target(target_adaptor_container): target_adaptor = target_adaptor_container.value """Construct a HydratedTarget from a TargetAdaptor and hydrated versions of its adapted fields.""" # Hydrate the fields of the adaptor and re-construct it. - hydrated_fields = yield [(Get(HydratedField, BundlesField, fa) - if type(fa) is BundlesField - else Get(HydratedField, SourcesField, fa)) + hydrated_fields = yield [Get(HydratedField, HydrateableField, fa) for fa in target_adaptor.field_adaptors] kwargs = target_adaptor.kwargs() for field in hydrated_fields: diff --git a/src/python/pants/engine/legacy/structs.py b/src/python/pants/engine/legacy/structs.py index fd456b1541d..ff82754861a 100644 --- a/src/python/pants/engine/legacy/structs.py +++ b/src/python/pants/engine/legacy/structs.py @@ -15,6 +15,7 @@ from pants.engine.addressable import addressable_list from pants.engine.fs import GlobExpansionConjunction, PathGlobs from pants.engine.objects import Locatable +from pants.engine.rules import UnionRule, union from pants.engine.struct import Struct, StructWithDeps from pants.source import wrapped_globs from pants.util.collections_abc_backport import MutableSequence, MutableSet @@ -129,6 +130,10 @@ class Field(object): """A marker for Target(Adaptor) fields for which the engine might perform extra construction.""" +@union +class HydrateableField(object): pass + + class SourcesField( datatype(['address', 'arg', 'filespecs', 'base_globs', 'path_globs', 'validate_fn']), Field @@ -426,3 +431,10 @@ class GlobsWithConjunction(datatype([ @classmethod def for_literal_files(cls, file_paths, spec_path): return cls(Files(*file_paths, spec_path=spec_path), GlobExpansionConjunction.all_match) + + +def rules(): + return [ + UnionRule(HydrateableField, SourcesField), + UnionRule(HydrateableField, BundlesField), + ] diff --git a/src/python/pants/engine/native.py b/src/python/pants/engine/native.py index a7ed59434e2..98bd3884ceb 100644 --- a/src/python/pants/engine/native.py +++ b/src/python/pants/engine/native.py @@ -423,30 +423,30 @@ def extern_generator_send(self, context_handle, func, arg): if isinstance(res, Get): # Get. values = [res.subject] - constraints = [constraint_for(res.product)] + products = [constraint_for(res.product)] tag = 2 elif type(res) in (tuple, list): # GetMulti. values = [g.subject for g in res] - constraints = [constraint_for(g.product) for g in res] + products = [constraint_for(g.product) for g in res] tag = 3 else: # Break. values = [res] - constraints = [] + products = [] tag = 0 except Exception as e: # Throw. val = e val._formatted_exc = traceback.format_exc() values = [val] - constraints = [] + products = [] tag = 1 return ( tag, c.vals_buf([c.to_value(v) for v in values]), - c.vals_buf([c.to_value(v) for v in constraints]) + c.vals_buf([c.to_value(v) for v in products]), ) @_extern_decl('PyResult', ['ExternContext*', 'Handle*', 'Handle**', 'uint64_t']) diff --git a/src/python/pants/engine/rules.py b/src/python/pants/engine/rules.py index 9120763fc94..d50351a484a 100644 --- a/src/python/pants/engine/rules.py +++ b/src/python/pants/engine/rules.py @@ -275,7 +275,9 @@ def resolve_type(name): parents_table=parents_table, ) rule_visitor.visit(rule_func_node) - gets.update(Get(resolve_type(p), resolve_type(s)) for p, s in rule_visitor.gets) + gets.update( + Get.create_statically_for_rule_graph(resolve_type(p), resolve_type(s)) + for p, s in rule_visitor.gets) # For @console_rule, redefine the function to avoid needing a literal return of the output type. if for_goal: @@ -314,6 +316,50 @@ def console_rule(goal_name, input_selectors): return _make_rule(output_type, input_selectors, goal_name, False) +def union(cls): + """A class decorator which other classes can specify that they can resolve to with `UnionRule`. + + Annotating a class with @union allows other classes to use a UnionRule() instance to indicate that + they can be resolved to this base union class. This class will never be instantiated, and should + have no members -- it is used as a tag only, and will be replaced with whatever object is passed + in as the subject of a `yield Get(...)`. See the following example: + + @union + class UnionBase(object): pass + + @rule(B, [Select(X)]) + def get_some_union_type(x): + result = yield Get(ResultType, UnionBase, x.f()) + # ... + + If there exists a single path from (whatever type the expression `x.f()` returns) -> `ResultType` + in the rule graph, the engine will retrieve and execute that path to produce a `ResultType` from + `x.f()`. This requires also that whatever type `x.f()` returns was registered as a union member of + `UnionBase` with a `UnionRule`. + + Unions allow @rule bodies to be written without knowledge of what types may eventually be provided + as input -- rather, they let the engine check that there is a valid path to the desired result. + """ + # TODO: Check that the union base type is used as a tag and nothing else (e.g. no attributes)! + assert isinstance(cls, type) + return type(cls.__name__, (cls,), { + '_is_union': True, + }) + + +class UnionRule(datatype([ + ('union_base', type), + ('union_member', type), +])): + """Specify that an instance of `union_member` can be substituted wherever `union_base` is used.""" + + def __new__(cls, union_base, union_member): + if not getattr(union_base, '_is_union', False): + raise cls.make_type_error('union_base must be a type annotated with @union: was {} (type {})' + .format(union_base, type(union_base).__name__)) + return super(UnionRule, cls).__new__(cls, union_base, union_member) + + class Rule(AbstractClass): """Rules declare how to produce products for the product graph. @@ -375,9 +421,12 @@ def __new__(cls, ) def __str__(self): - return '({}, {!r}, {})'.format(type_or_constraint_repr(self.output_constraint), - self.input_selectors, - self.func.__name__) + return ('({}, {!r}, {}, gets={}, opts={})' + .format(type_or_constraint_repr(self.output_constraint), + self.input_selectors, + self.func.__name__, + self.input_gets, + self.dependency_optionables)) class SingletonRule(datatype(['output_constraint', 'value']), Rule): @@ -420,49 +469,67 @@ def dependency_optionables(self): return tuple() -class RuleIndex(datatype(['rules', 'roots'])): +class RuleIndex(datatype(['rules', 'roots', 'union_rules'])): """Holds a normalized index of Rules used to instantiate Nodes.""" @classmethod - def create(cls, rule_entries): + def create(cls, rule_entries, union_rules=None): """Creates a RuleIndex with tasks indexed by their output type.""" serializable_rules = OrderedDict() serializable_roots = OrderedSet() + union_rules = OrderedDict(union_rules or ()) def add_task(product_type, rule): if product_type not in serializable_rules: serializable_rules[product_type] = OrderedSet() serializable_rules[product_type].add(rule) + def add_root_rule(root_rule): + serializable_roots.add(root_rule) + def add_rule(rule): if isinstance(rule, RootRule): - serializable_roots.add(rule) - return - # TODO: Ensure that interior types work by indexing on the list of types in - # the constraint. This heterogenity has some confusing implications: - # see https://github.com/pantsbuild/pants/issues/4005 - for kind in rule.output_constraint.types: - add_task(kind, rule) - add_task(rule.output_constraint, rule) + add_root_rule(rule) + else: + # TODO: Ensure that interior types work by indexing on the list of types in + # the constraint. This heterogenity has some confusing implications: + # see https://github.com/pantsbuild/pants/issues/4005 + for kind in rule.output_constraint.types: + add_task(kind, rule) + add_task(rule.output_constraint, rule) + + def add_type_transition_rule(union_rule): + # NB: This does not require that union bases be supplied to `def rules():`, as the union type + # is never instantiated! + union_base = union_rule.union_base + assert union_base._is_union + union_member = union_rule.union_member + if union_base not in union_rules: + union_rules[union_base] = OrderedSet() + union_rules[union_base].add(union_member) for entry in rule_entries: if isinstance(entry, Rule): add_rule(entry) + elif isinstance(entry, UnionRule): + add_type_transition_rule(entry) elif hasattr(entry, '__call__'): rule = getattr(entry, 'rule', None) if rule is None: raise TypeError("Expected callable {} to be decorated with @rule.".format(entry)) add_rule(rule) else: - raise TypeError("Unexpected rule type: {}. " - "Rules either extend Rule, or are static functions " - "decorated with @rule.".format(type(entry))) + raise TypeError("""\ +Unexpected rule type: {}. Rules either extend Rule or UnionRule, or are static functions decorated \ +with @rule.""".format(type(entry))) + + return cls(serializable_rules, serializable_roots, union_rules) - return cls(serializable_rules, serializable_roots) + class NormalizedRules(datatype(['rules', 'union_rules'])): pass def normalized_rules(self): rules = OrderedSet(rule for ruleset in self.rules.values() for rule in ruleset) rules.update(self.roots) - return rules + return self.NormalizedRules(rules, self.union_rules) diff --git a/src/python/pants/engine/scheduler.py b/src/python/pants/engine/scheduler.py index c4374c2f713..808b0c39979 100644 --- a/src/python/pants/engine/scheduler.py +++ b/src/python/pants/engine/scheduler.py @@ -55,6 +55,7 @@ def __init__( work_dir, local_store_dir, rules, + union_rules, execution_options, include_trace_on_error=True, validate=True, @@ -66,6 +67,8 @@ def __init__( :param work_dir: The pants work dir. :param local_store_dir: The directory to use for storing the engine's LMDB store in. :param rules: A set of Rules which is used to compute values in the graph. + :param union_rules: A dict mapping union base types to member types so that rules can be written + against abstract union types without knowledge of downstream rulesets. :param execution_options: Execution options for (remote) processes. :param include_trace_on_error: Include the trace through the graph upon encountering errors. :type include_trace_on_error: bool @@ -79,7 +82,7 @@ def __init__( self.include_trace_on_error = include_trace_on_error self._visualize_to_dir = visualize_to_dir # Validate and register all provided and intrinsic tasks. - rule_index = RuleIndex.create(list(rules)) + rule_index = RuleIndex.create(list(rules), union_rules) self._root_subject_types = [r.output_constraint for r in rule_index.roots] # Create the native Scheduler and Session. @@ -188,7 +191,7 @@ def _register_rules(self, rule_index): if type(rule) is SingletonRule: self._register_singleton(output_constraint, rule) elif type(rule) is TaskRule: - self._register_task(output_constraint, rule) + self._register_task(output_constraint, rule, rule_index.union_rules) else: raise ValueError('Unexpected Rule type: {}'.format(rule)) @@ -201,7 +204,7 @@ def _register_singleton(self, output_constraint, rule): self._to_value(rule.value), output_constraint) - def _register_task(self, output_constraint, rule): + def _register_task(self, output_constraint, rule, union_rules): """Register the given TaskRule with the native scheduler.""" func = Function(self._to_key(rule.func)) self._native.lib.tasks_task_begin(self._tasks, func, output_constraint, rule.cacheable) @@ -212,10 +215,21 @@ def _register_task(self, output_constraint, rule): self._native.lib.tasks_add_select(self._tasks, product_constraint) else: raise ValueError('Unrecognized Selector type: {}'.format(selector)) - for get in rule.input_gets: + + def add_get_edge(product, subject): self._native.lib.tasks_add_get(self._tasks, - self._to_constraint(get.product), - TypeId(self._to_id(get.subject))) + self._to_constraint(product), + TypeId(self._to_id(subject))) + + for get in rule.input_gets: + union_members = union_rules.get(get.subject_declared_type, None) + if union_members: + # If the registered subject type is a union, add get edges to all registered union members. + for union_member in union_members: + add_get_edge(get.product, union_member) + else: + # Otherwise, the Get subject is a "concrete" type, so add a single get edge. + add_get_edge(get.product, get.subject_declared_type) self._native.lib.tasks_task_end(self._tasks) def visualize_graph_to_file(self, session, filename): @@ -490,7 +504,6 @@ def run_console_rule(self, product, subject): """ :param product: product type for the request. :param subject: subject for the request. - :param v2_ui: whether to render the v2 engine UI """ request = self.execution_request([product], [subject]) returns, throws = self.execute(request) diff --git a/src/python/pants/engine/selectors.py b/src/python/pants/engine/selectors.py index 8a7ac575bd6..667d29f803c 100644 --- a/src/python/pants/engine/selectors.py +++ b/src/python/pants/engine/selectors.py @@ -5,9 +5,8 @@ from __future__ import absolute_import, division, print_function, unicode_literals import ast -from builtins import str -from pants.util.objects import Exactly, datatype +from pants.util.objects import Exactly, TypeConstraint, datatype def type_or_constraint_repr(constraint): @@ -28,11 +27,11 @@ def constraint_for(type_or_constraint): raise TypeError("Expected a type or constraint: got: {}".format(type_or_constraint)) -class Get(datatype(['product', 'subject'])): +class Get(datatype(['product', 'subject_declared_type', 'subject'])): """Experimental synchronous generator API. May be called equivalently as either: - # verbose form: Get(product_type, subject_type, subject) + # verbose form: Get(product_type, subject_declared_type, subject) # shorthand form: Get(product_type, subject_type(subject)) """ @@ -44,36 +43,66 @@ def extract_constraints(call_node): :return: A tuple of product type id and subject type id. """ def render_args(): - return ', '.join(a.id for a in call_node.args) + return ', '.join( + # Dump the Name's id to simplify output when available, falling back to the name of the + # node's class. + getattr(a, 'id', type(a).__name__) + for a in call_node.args) if len(call_node.args) == 2: product_type, subject_constructor = call_node.args if not isinstance(product_type, ast.Name) or not isinstance(subject_constructor, ast.Call): - raise ValueError('Two arg form of {} expected (product_type, subject_type(subject)), but ' + # TODO(#7114): describe what types of objects are expected in the get call, not just the + # argument names. After #7114 this will be easier because they will just be types! + raise ValueError( + 'Two arg form of {} expected (product_type, subject_type(subject)), but ' 'got: ({})'.format(Get.__name__, render_args())) return (product_type.id, subject_constructor.func.id) elif len(call_node.args) == 3: - product_type, subject_type, _ = call_node.args - if not isinstance(product_type, ast.Name) or not isinstance(subject_type, ast.Name): - raise ValueError('Three arg form of {} expected (product_type, subject_type, subject), but ' + product_type, subject_declared_type, _ = call_node.args + if not isinstance(product_type, ast.Name) or not isinstance(subject_declared_type, ast.Name): + raise ValueError( + 'Three arg form of {} expected (product_type, subject_declared_type, subject), but ' 'got: ({})'.format(Get.__name__, render_args())) - return (product_type.id, subject_type.id) + return (product_type.id, subject_declared_type.id) else: raise ValueError('Invalid {}; expected either two or three args, but ' 'got: ({})'.format(Get.__name__, render_args())) + @classmethod + def create_statically_for_rule_graph(cls, product_type, subject_type): + """Construct a `Get` with a None value. + + This method is used to help make it explicit which `Get` instances are parsed from @rule bodies + and which are instantiated during rule execution. + """ + return cls(product_type, subject_type, None) + def __new__(cls, *args): + # TODO(#7114): Use datatype type checking for these fields! We can wait until after #7114, when + # we can just check that they are types. if len(args) == 2: product, subject = args + + if isinstance(subject, (type, TypeConstraint)): + raise TypeError("""\ +The two-argument form of Get does not accept a type as its second argument. + +args were: Get({args!r}) + +Get.create_statically_for_rule_graph() should be used to generate a Get() for +the `input_gets` field of a rule. If you are using a `yield Get(...)` in a rule +and a type was intended, use the 3-argument version: +Get({product!r}, {subject_type!r}, {subject!r}) +""".format(args=args, product=product, subject_type=type(subject), subject=subject)) + + subject_declared_type = type(subject) elif len(args) == 3: - product, subject_type, subject = args - if type(subject) is not subject_type: - raise TypeError('Declared type did not match actual type for {}({}).'.format( - Get.__name__, ', '.join(str(a) for a in args))) + product, subject_declared_type, subject = args else: - raise Exception('Expected either two or three arguments to {}; got {}.'.format( - Get.__name__, args)) - return super(Get, cls).__new__(cls, product, subject) + raise ValueError('Expected either two or three arguments to {}; got {}.' + .format(Get.__name__, args)) + return super(Get, cls).__new__(cls, product, subject_declared_type, subject) class Params(datatype([('params', tuple)])): diff --git a/src/python/pants/init/engine_initializer.py b/src/python/pants/init/engine_initializer.py index 8a59e829e2c..86d8218a1ff 100644 --- a/src/python/pants/init/engine_initializer.py +++ b/src/python/pants/init/engine_initializer.py @@ -31,6 +31,7 @@ PantsPluginAdaptor, PythonBinaryAdaptor, PythonTargetAdaptor, PythonTestsAdaptor, RemoteSourcesAdaptor, TargetAdaptor) +from pants.engine.legacy.structs import rules as structs_rules from pants.engine.mapper import AddressMapper from pants.engine.parser import SymbolTable from pants.engine.rules import RootRule, SingletonRule @@ -357,6 +358,7 @@ def setup_legacy_graph_extended( create_process_rules() + create_graph_rules(address_mapper) + create_options_parsing_rules() + + structs_rules() + # TODO: This should happen automatically, but most tests (e.g. tests/python/pants_test/auth) fail if it's not here: python_test_runner.rules() + rules @@ -364,12 +366,15 @@ def setup_legacy_graph_extended( goal_map = EngineInitializer._make_goal_map_from_rules(rules) + union_rules = build_configuration.union_rules() + scheduler = Scheduler( native, project_tree, workdir, local_store_dir, rules, + union_rules, execution_options, include_trace_on_error=include_trace_on_error, visualize_to_dir=bootstrap_options.native_engine_visualize_to, diff --git a/src/python/pants/option/optionable.py b/src/python/pants/option/optionable.py index b71c9ef685f..3b239c476d7 100644 --- a/src/python/pants/option/optionable.py +++ b/src/python/pants/option/optionable.py @@ -49,7 +49,7 @@ def signature(cls): output_type=cls.optionable_cls, input_selectors=tuple(), func=partial_construct_optionable, - input_gets=(Get(ScopedOptions, Scope),), + input_gets=(Get.create_statically_for_rule_graph(ScopedOptions, Scope),), dependency_optionables=(cls.optionable_cls,), ) diff --git a/src/rust/engine/src/core.rs b/src/rust/engine/src/core.rs index 1f065fa6fa4..113f5eb9b4f 100644 --- a/src/rust/engine/src/core.rs +++ b/src/rust/engine/src/core.rs @@ -116,9 +116,19 @@ pub type Id = u64; // The type of a python object (which itself has a type, but which is not represented // by a Key, because that would result in a infinitely recursive structure.) #[repr(C)] -#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Copy, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct TypeId(pub Id); +impl fmt::Debug for TypeId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if *self == ANY_TYPE { + write!(f, "Any") + } else { + write!(f, "{}", externs::type_to_str(*self)) + } + } +} + impl fmt::Display for TypeId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if *self == ANY_TYPE { @@ -146,12 +156,18 @@ pub struct Function(pub Key); /// Wraps a type id for use as a key in HashMaps and sets. /// #[repr(C)] -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy)] pub struct Key { id: Id, type_id: TypeId, } +impl fmt::Debug for Key { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", externs::key_to_str(self)) + } +} + impl Eq for Key {} impl PartialEq for Key { diff --git a/src/rust/engine/src/externs.rs b/src/rust/engine/src/externs.rs index c81bee215e1..3198473ec4f 100644 --- a/src/rust/engine/src/externs.rs +++ b/src/rust/engine/src/externs.rs @@ -226,20 +226,28 @@ pub fn generator_send(generator: &Value, arg: &Value) -> Result Err(PyResult::failure_from(response.values.unwrap_one())), PyGeneratorResponseType::Get => { let mut interns = INTERNS.write(); - let constraint = TypeConstraint(interns.insert(response.constraints.unwrap_one())); - Ok(GeneratorResponse::Get(Get( - constraint, - interns.insert(response.values.unwrap_one()), - ))) + let product = TypeConstraint(interns.insert(response.products.unwrap_one())); + let subject = interns.insert(response.values.unwrap_one()); + Ok(GeneratorResponse::Get(Get { product, subject })) } PyGeneratorResponseType::GetMulti => { let mut interns = INTERNS.write(); - let continues = response - .constraints - .to_vec() + let PyGeneratorResponse { + values: values_buf, + products: products_buf, + .. + } = response; + let values = values_buf.to_vec(); + let products = products_buf.to_vec(); + assert_eq!(values.len(), products.len()); + let continues: Vec = values .into_iter() - .zip(response.values.to_vec().into_iter()) - .map(|(c, v)| Get(TypeConstraint(interns.insert(c)), interns.insert(v))) + .zip(products.into_iter()) + .map(|(val, prod)| { + let subject = interns.insert(val); + let product = TypeConstraint(interns.insert(prod)); + Get { subject, product } + }) .collect(); Ok(GeneratorResponse::GetMulti(continues)) } @@ -476,11 +484,15 @@ pub enum PyGeneratorResponseType { pub struct PyGeneratorResponse { res_type: PyGeneratorResponseType, values: HandleBuffer, - constraints: HandleBuffer, + products: HandleBuffer, } #[derive(Debug)] -pub struct Get(pub TypeConstraint, pub Key); +pub struct Get { + // TODO(#7114): convert all of these into `TypeId`s! + pub product: TypeConstraint, + pub subject: Key, +} pub enum GeneratorResponse { Break(Value), diff --git a/src/rust/engine/src/lib.rs b/src/rust/engine/src/lib.rs index 9aacd3f2ac9..a6413b5073b 100644 --- a/src/rust/engine/src/lib.rs +++ b/src/rust/engine/src/lib.rs @@ -582,6 +582,7 @@ pub extern "C" fn rule_graph_visualize( let path_str = unsafe { CStr::from_ptr(path_ptr).to_string_lossy().into_owned() }; let path = PathBuf::from(path_str); + // TODO(#7117): we want to represent union types in the graph visualizer somehow!!! let graph = graph_full(scheduler, subject_types.to_vec()); write_to_file(path.as_path(), &graph).unwrap_or_else(|e| { println!("Failed to visualize to {}: {:?}", path.display(), e); @@ -600,6 +601,7 @@ pub extern "C" fn rule_subgraph_visualize( let path_str = unsafe { CStr::from_ptr(path_ptr).to_string_lossy().into_owned() }; let path = PathBuf::from(path_str); + // TODO(#7117): we want to represent union types in the graph visualizer somehow!!! let graph = graph_sub(scheduler, subject_type, product_type); write_to_file(path.as_path(), &graph).unwrap_or_else(|e| { println!("Failed to visualize to {}: {:?}", path.display(), e); diff --git a/src/rust/engine/src/nodes.rs b/src/rust/engine/src/nodes.rs index e320b35ced9..599dda93c10 100644 --- a/src/rust/engine/src/nodes.rs +++ b/src/rust/engine/src/nodes.rs @@ -800,29 +800,33 @@ impl Task { ) -> NodeFuture> { let get_futures = gets .into_iter() - .map(|externs::Get(product, subject)| { + .map(|get| { + let context = context.clone(); + let params = params.clone(); + let entry = entry.clone(); let select_key = rule_graph::SelectKey::JustGet(selectors::Get { - product: product, - subject: *subject.type_id(), + product: get.product, + subject: *get.subject.type_id(), }); let entry = context .core .rule_graph - .edges_for_inner(entry) - .expect("edges for task exist.") - .entry_for(&select_key) - .unwrap_or_else(|| { - panic!( - "{:?} did not declare a dependency on {:?}", - entry, select_key - ) - }) - .clone(); + .edges_for_inner(&entry) + .ok_or_else(|| throw(&format!("no edges for task {:?} exist!", entry))) + .and_then(|edges| { + edges.entry_for(&select_key).cloned().ok_or_else(|| { + throw(&format!( + "{:?} did not declare a dependency on {:?}", + entry, select_key + )) + }) + }); // The subject of the get is a new parameter that replaces an existing param of the same // type. let mut params = params.clone(); - params.put(subject); - Select::new(params, product, entry).run(context.clone()) + params.put(get.subject); + future::result(entry) + .and_then(move |entry| Select::new(params, get.product, entry).run(context.clone())) }) .collect::>(); future::join_all(get_futures).to_boxed() diff --git a/tests/python/pants_test/engine/scheduler_test_base.py b/tests/python/pants_test/engine/scheduler_test_base.py index 200bc6c778b..675bdec4aaa 100644 --- a/tests/python/pants_test/engine/scheduler_test_base.py +++ b/tests/python/pants_test/engine/scheduler_test_base.py @@ -46,6 +46,7 @@ def mk_fs_tree(self, build_root_src=None, ignore_patterns=None, work_dir=None): def mk_scheduler(self, rules=None, + union_rules=None, project_tree=None, work_dir=None, include_trace_on_error=True): @@ -59,6 +60,7 @@ def mk_scheduler(self, work_dir, local_store_dir, rules, + union_rules, DEFAULT_EXECUTION_OPTIONS, include_trace_on_error=include_trace_on_error) return scheduler.new_session() diff --git a/tests/python/pants_test/engine/test_rules.py b/tests/python/pants_test/engine/test_rules.py index c72d1c2694b..b8bfd5eb218 100644 --- a/tests/python/pants_test/engine/test_rules.py +++ b/tests/python/pants_test/engine/test_rules.py @@ -78,8 +78,9 @@ class RuleIndexTest(unittest.TestCase): def test_creation_fails_with_bad_declaration_type(self): with self.assertRaises(TypeError) as cm: RuleIndex.create([A()]) - self.assertEqual("Unexpected rule type: ." - " Rules either extend Rule, or are static functions decorated with @rule.", + self.assertEqual("""\ +Unexpected rule type: . Rules either extend Rule or \ +UnionRule, or are static functions decorated with @rule.""", str(cm.exception)) diff --git a/tests/python/pants_test/engine/test_scheduler.py b/tests/python/pants_test/engine/test_scheduler.py index 586fbfb6e00..e4fedb8ece4 100644 --- a/tests/python/pants_test/engine/test_scheduler.py +++ b/tests/python/pants_test/engine/test_scheduler.py @@ -4,11 +4,14 @@ from __future__ import absolute_import, division, print_function, unicode_literals +import re import unittest from builtins import object, str +from contextlib import contextmanager from textwrap import dedent -from pants.engine.rules import RootRule, rule +from pants.engine.rules import RootRule, UnionRule, rule, union +from pants.engine.scheduler import ExecutionError from pants.engine.selectors import Get, Params, Select from pants.util.objects import datatype from pants_test.engine.util import (assert_equal_with_printing, create_scheduler, @@ -57,6 +60,62 @@ def transitive_coroutine_rule(c): yield D(b) +@union +class UnionBase(object): pass + + +class UnionWrapper(object): + def __init__(self, inner): + self.inner = inner + + +class UnionA(object): + + def a(self): + return A() + + +@rule(A, [Select(UnionA)]) +def select_union_a(union_a): + return union_a.a() + + +class UnionB(object): + + def a(self): + return A() + + +@rule(A, [Select(UnionB)]) +def select_union_b(union_b): + return union_b.a() + + +# TODO: add GetMulti testing for unions! +@rule(A, [Select(UnionWrapper)]) +def a_union_test(union_wrapper): + union_a = yield Get(A, UnionBase, union_wrapper.inner) + yield union_a + + +class TypeCheckFailWrapper(object): + """ + This object wraps another object which will be used to demonstrate a type check failure when the + engine processes a `yield Get(...)` statement. + """ + + def __init__(self, inner): + self.inner = inner + + +@rule(A, [Select(TypeCheckFailWrapper)]) +def a_typecheck_fail_test(wrapper): + # This `yield Get(A, B, ...)` will use the `nested_raise` rule defined above, but it won't get to + # the point of raising since the type check will fail at the Get. + supposedly_a = yield Get(A, B, wrapper.inner) + yield supposedly_a + + class SchedulerTest(TestBase): @classmethod @@ -69,6 +128,16 @@ def rules(cls): consumes_a_and_b, transitive_b_c, transitive_coroutine_rule, + RootRule(UnionWrapper), + UnionRule(UnionBase, UnionA), + RootRule(UnionA), + select_union_a, + UnionRule(union_base=UnionBase, union_member=UnionB), + RootRule(UnionB), + select_union_b, + a_union_test, + a_typecheck_fail_test, + RootRule(TypeCheckFailWrapper), ] def test_use_params(self): @@ -82,7 +151,9 @@ def test_use_params(self): self.assertEquals(result_str, consumes_a_and_b(a, b)) # But not a subset. - with self.assertRaises(Exception): + expected_msg = ("No installed @rules can satisfy Select({}) for input Params(A)" + .format(str.__name__)) + with self.assertRaisesRegexp(Exception, re.escape(expected_msg)): self.scheduler.product_request(str, [Params(a)]) def test_transitive_params(self): @@ -100,6 +171,34 @@ def test_transitive_params(self): # we're just testing transitively resolving products in this file. self.assertTrue(isinstance(result_d, D)) + @contextmanager + def _assert_execution_error(self, expected_msg): + with self.assertRaises(ExecutionError) as cm: + yield + self.assertIn(expected_msg, remove_locations_from_traceback(str(cm.exception))) + + def test_union_rules(self): + a, = self.scheduler.product_request(A, [Params(UnionWrapper(UnionA()))]) + # TODO: figure out what to assert here! + self.assertTrue(isinstance(a, A)) + a, = self.scheduler.product_request(A, [Params(UnionWrapper(UnionB()))]) + self.assertTrue(isinstance(a, A)) + # Fails due to no union relationship from A -> UnionBase. + expected_msg = """\ +Exception: WithDeps(Inner(InnerEntry { params: {UnionWrapper}, rule: Task(Task { product: TypeConstraint(Exactly(A)), clause: [Select { product: Exactly(UnionWrapper) }], gets: [Get { product: TypeConstraint(Exactly(A)), subject: UnionA }, Get { product: TypeConstraint(Exactly(A)), subject: UnionB }], func: Function(), cacheable: true }) })) did not declare a dependency on JustGet(Get { product: TypeConstraint(Exactly(A)), subject: A }) +""" + with self._assert_execution_error(expected_msg): + self.scheduler.product_request(A, [Params(UnionWrapper(A()))]) + + def test_get_type_match_failure(self): + """Test that Get(...)s are now type-checked during rule execution, to allow for union types.""" + expected_msg = """\ +Exception: WithDeps(Inner(InnerEntry { params: {TypeCheckFailWrapper}, rule: Task(Task { product: TypeConstraint(Exactly(A)), clause: [Select { product: Exactly(TypeCheckFailWrapper) }], gets: [Get { product: TypeConstraint(Exactly(A)), subject: B }], func: Function(), cacheable: true }) })) did not declare a dependency on JustGet(Get { product: TypeConstraint(Exactly(A)), subject: A }) +""" + with self._assert_execution_error(expected_msg): + # `a_typecheck_fail_test` above expects `wrapper.inner` to be a `B`. + self.scheduler.product_request(A, [Params(TypeCheckFailWrapper(A()))]) + class SchedulerTraceTest(unittest.TestCase): assert_equal_with_printing = assert_equal_with_printing diff --git a/tests/python/pants_test/engine/test_selectors.py b/tests/python/pants_test/engine/test_selectors.py index cf172fd8763..8d1edcb9771 100644 --- a/tests/python/pants_test/engine/test_selectors.py +++ b/tests/python/pants_test/engine/test_selectors.py @@ -4,8 +4,9 @@ from __future__ import absolute_import, division, print_function, unicode_literals +import ast import unittest -from builtins import object +from builtins import object, str from pants.engine.selectors import Get, Select @@ -15,7 +16,9 @@ class AClass(object): class BClass(object): - pass + + def __eq__(self, other): + return type(self) == type(other) class SubBClass(BClass): @@ -32,8 +35,51 @@ def assert_repr(self, expected, selector): class GetTest(unittest.TestCase): - def test_get(self): - sub_b = SubBClass() + def test_create(self): + # Test the equivalence of the 2-arg and 3-arg versions. + self.assertEqual(Get(AClass, BClass()), + Get(AClass, BClass, BClass())) + with self.assertRaises(TypeError) as cm: - Get(AClass, BClass, sub_b) - self.assertIn("Declared type did not match actual type", str(cm.exception)) + Get(AClass, BClass) + self.assertEqual("""\ +The two-argument form of Get does not accept a type as its second argument. + +args were: Get(({a!r}, {b!r})) + +Get.create_statically_for_rule_graph() should be used to generate a Get() for +the `input_gets` field of a rule. If you are using a `yield Get(...)` in a rule +and a type was intended, use the 3-argument version: +Get({a!r}, {t!r}, {b!r}) +""".format(a=AClass, t=type(BClass), b=BClass), str(cm.exception)) + + with self.assertRaises(ValueError) as cm: + Get(1) + self.assertEqual("Expected either two or three arguments to Get; got (1,).", + str(cm.exception)) + + def _get_call_node(self, input_string): + return ast.parse(input_string).body[0].value + + def test_extract_constraints(self): + parsed_two_arg_call = self._get_call_node("Get(A, B(x))") + self.assertEqual(('A', 'B'), + Get.extract_constraints(parsed_two_arg_call)) + + with self.assertRaises(ValueError) as cm: + Get.extract_constraints(self._get_call_node("Get(1, 2)")) + self.assertEqual(str(cm.exception), """\ +Two arg form of Get expected (product_type, subject_type(subject)), but got: (Num, Num)""") + + parsed_three_arg_call = self._get_call_node("Get(A, B, C(x))") + self.assertEqual(('A', 'B'), + Get.extract_constraints(parsed_three_arg_call)) + + with self.assertRaises(ValueError) as cm: + Get.extract_constraints(self._get_call_node("Get(A, 'asdf', C(x))")) + self.assertEqual(str(cm.exception), """\ +Three arg form of Get expected (product_type, subject_declared_type, subject), but got: (A, Str, Call)""") + + def test_create_statically_for_rule_graph(self): + self.assertEqual(Get(AClass, BClass, None), + Get.create_statically_for_rule_graph(AClass, BClass)) diff --git a/tests/python/pants_test/engine/util.py b/tests/python/pants_test/engine/util.py index 8a7be93daae..7bcf525194a 100644 --- a/tests/python/pants_test/engine/util.py +++ b/tests/python/pants_test/engine/util.py @@ -87,7 +87,7 @@ def init_native(): return Native() -def create_scheduler(rules, validate=True, native=None): +def create_scheduler(rules, union_rules=None, validate=True, native=None): """Create a Scheduler.""" native = native or init_native() return Scheduler( @@ -96,6 +96,7 @@ def create_scheduler(rules, validate=True, native=None): './.pants.d', safe_mkdtemp(), rules, + union_rules, execution_options=DEFAULT_EXECUTION_OPTIONS, validate=validate, )