Skip to content

Commit

Permalink
union rules work now
Browse files Browse the repository at this point in the history
  • Loading branch information
cosmicexplorer committed Jan 20, 2019
1 parent 5749bb0 commit 166beac
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 45 deletions.
4 changes: 2 additions & 2 deletions src/python/pants/engine/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,11 +439,11 @@ class RuleIndex(datatype(['rules', 'roots', 'union_rules'])):
"""Holds a normalized index of Rules used to instantiate Nodes."""

@classmethod
def create(cls, rule_entries):
def create(cls, rule_entries, union_rules=None):
"""Creates a RuleIndex with tasks indexed by their output type."""
serializable_rules = OrderedDict()
serializable_roots = OrderedSet()
union_rules = OrderedDict()
union_rules = OrderedDict(union_rules or ())

def add_task(product_type, rule):
if product_type not in serializable_rules:
Expand Down
25 changes: 18 additions & 7 deletions src/python/pants/engine/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(
self.include_trace_on_error = include_trace_on_error
self._visualize_to_dir = visualize_to_dir
# Validate and register all provided and intrinsic tasks.
rule_index = RuleIndex.create(list(rules))
rule_index = RuleIndex.create(list(rules), union_rules)
self._root_subject_types = [r.output_constraint for r in rule_index.roots]

# Create the native Scheduler and Session.
Expand All @@ -93,7 +93,7 @@ def __init__(
self._scheduler = native.new_scheduler(
tasks=self._tasks,
root_subject_types=self._root_subject_types,
union_rules=union_rules,
union_rules=rule_index.union_rules,
build_root=project_tree.build_root,
work_dir=work_dir,
local_store_dir=local_store_dir,
Expand Down Expand Up @@ -195,7 +195,7 @@ def _register_rules(self, rule_index):
if type(rule) is SingletonRule:
self._register_singleton(output_constraint, rule)
elif type(rule) is TaskRule:
self._register_task(output_constraint, rule)
self._register_task(output_constraint, rule, rule_index.union_rules)
else:
raise ValueError('Unexpected Rule type: {}'.format(rule))

Expand All @@ -208,7 +208,7 @@ def _register_singleton(self, output_constraint, rule):
self._to_value(rule.value),
output_constraint)

def _register_task(self, output_constraint, rule):
def _register_task(self, output_constraint, rule, union_rules):
"""Register the given TaskRule with the native scheduler."""
func = Function(self._to_key(rule.func))
self._native.lib.tasks_task_begin(self._tasks, func, output_constraint, rule.cacheable)
Expand All @@ -219,10 +219,21 @@ def _register_task(self, output_constraint, rule):
self._native.lib.tasks_add_select(self._tasks, product_constraint)
else:
raise ValueError('Unrecognized Selector type: {}'.format(selector))
for get in rule.input_gets:

def add_get_edge(product, subject):
self._native.lib.tasks_add_get(self._tasks,
self._to_constraint(get.product),
TypeId(self._to_id(get.subject)))
self._to_constraint(product),
TypeId(self._to_id(subject)))

for get in rule.input_gets:
union_members = union_rules.get(get.subject, None)
if union_members:
# If the registered subject type is a union, add get edges to all registered union members.
for union_member in union_members:
add_get_edge(get.product, union_member)
else:
# Otherwise, the Get subject is a "concrete" type, so add a single get edge.
add_get_edge(get.product, get.subject)
self._native.lib.tasks_task_end(self._tasks)

def visualize_graph_to_file(self, session, filename):
Expand Down
29 changes: 1 addition & 28 deletions tests/python/pants_test/engine/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pants.engine.fs import create_fs_rules
from pants.engine.mapper import AddressMapper
from pants.engine.rules import (RootRule, RuleIndex, SingletonRule, _GoalProduct, _RuleVisitor,
console_rule, rule, union, union_rule)
console_rule, rule)
from pants.engine.selectors import Get, Select
from pants.util.objects import Exactly
from pants_test.engine.examples.parsers import JsonParser
Expand Down Expand Up @@ -66,33 +66,6 @@ def a_console_rule_generator(console):
console.print_stdout(str(a))


@union
class UnionBase(object):
pass


class UnionWrapper(object):
def __init__(self, inner):
self.inner = inner


@union_rule(UnionBase)
class UnionA(object):

def a(self):
return A()


@union_rule(UnionBase)
class UnionB(object):

def a(self):
return A()


# TODO: test creating these, but testing for using them can go in test_scheduler.py!
# @rule(A, [Select(UnionWrapper)])
# def a_union_test(union_wrapper):
class RuleTest(unittest.TestCase):
def test_run_rule_console_rule_generator(self):
res = run_rule(a_console_rule_generator, Console(), {
Expand Down
26 changes: 19 additions & 7 deletions tests/python/pants_test/engine/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@

from __future__ import absolute_import, division, print_function, unicode_literals

import re
import unittest
from builtins import object, str
from contextlib import contextmanager
from textwrap import dedent

from pants.engine.rules import RootRule, rule, union, union_rule
from pants.engine.scheduler import ExecutionError
from pants.engine.selectors import Get, Params, Select
from pants.util.objects import datatype
from pants_test.engine.util import (assert_equal_with_printing, create_scheduler,
Expand Down Expand Up @@ -131,7 +134,9 @@ def test_use_params(self):
self.assertEquals(result_str, consumes_a_and_b(a, b))

# But not a subset.
with self.assertRaises(Exception):
expected_msg = ("No installed @rules can satisfy Select({}) for input Params(A)"
.format(str.__name__))
with self.assertRaisesRegexp(Exception, re.escape(expected_msg)):
self.scheduler.product_request(str, [Params(a)])

def test_transitive_params(self):
Expand All @@ -149,15 +154,22 @@ def test_transitive_params(self):
# we're just testing transitively resolving products in this file.
self.assertTrue(isinstance(result_d, D))

@contextmanager
def _assert_execution_error(self, expected_msg):
with self.assertRaises(ExecutionError) as cm:
yield
self.assertIn(expected_msg, remove_locations_from_traceback(str(cm.exception)))

def test_union_rules(self):
a = self.scheduler.product_request(A, [Params(UnionWrapper(UnionA()))])
a, = self.scheduler.product_request(A, [Params(UnionWrapper(UnionA()))])
# TODO: figure out what to assert here!
self.assertIsNotNone(a)
a = self.scheduler.product_request(A, [Params(UnionWrapper(UnionB()))])
self.assertIsNotNone(a)
self.assertTrue(isinstance(a, A))
a, = self.scheduler.product_request(A, [Params(UnionWrapper(UnionB()))])
self.assertTrue(isinstance(a, A))
# Fails due to no union relationship from A -> UnionBase.
a = self.scheduler.product_request(A, [Params(UnionWrapper(A()))])
self.assertIsNotNone(a)
expected_msg = "Exception: None of the registered union members matched the subject. declared union type: TypeConstraint(=UnionBase), union members: {TypeConstraint(=UnionA), TypeConstraint(=UnionB)}, subject: <pants_test.engine.test_scheduler.A object at 0xEEEEEEEEE>"
with self._assert_execution_error(expected_msg):
self.scheduler.product_request(A, [Params(UnionWrapper(A()))])


class SchedulerTraceTest(unittest.TestCase):
Expand Down
3 changes: 2 additions & 1 deletion tests/python/pants_test/engine/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def init_native():
return Native()


def create_scheduler(rules, validate=True, native=None):
def create_scheduler(rules, union_rules=None, validate=True, native=None):
"""Create a Scheduler."""
native = native or init_native()
return Scheduler(
Expand All @@ -96,6 +96,7 @@ def create_scheduler(rules, validate=True, native=None):
'./.pants.d',
safe_mkdtemp(),
rules,
union_rules,
execution_options=DEFAULT_EXECUTION_OPTIONS,
validate=validate,
)
Expand Down

0 comments on commit 166beac

Please sign in to comment.