Skip to content

Commit

Permalink
Implement @unions using embedded Querys.
Browse files Browse the repository at this point in the history
[ci skip-build-wheels]
  • Loading branch information
stuhood committed Sep 21, 2021
1 parent 1186172 commit 26fc055
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 31 deletions.
1 change: 1 addition & 0 deletions src/python/pants/engine/internals/native_engine.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def tasks_task_begin(
) -> None: ...
def tasks_task_end(tasks: PyTasks) -> None: ...
def tasks_add_get(tasks: PyTasks, output: type, input: type) -> None: ...
def tasks_add_union(tasks: PyTasks, output_type: type, input_type: tuple[type, ...]) -> None: ...
def tasks_add_select(tasks: PyTasks, selector: type) -> None: ...
def tasks_add_query(tasks: PyTasks, output_type: type, input_type: tuple[type, ...]) -> None: ...
def execution_add_root_select(
Expand Down
11 changes: 4 additions & 7 deletions src/python/pants/engine/internals/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,18 +659,15 @@ def register_task(rule: TaskRule) -> None:
for selector in rule.input_selectors:
native_engine.tasks_add_select(tasks, selector)

def add_get_edge(product: type, subject: type) -> None:
native_engine.tasks_add_get(tasks, product, subject)

for the_get in rule.input_gets:
if is_union(the_get.input_type):
# If the registered subject type is a union, add Get edges to all registered
# union members.
# Register a union. TODO: See #12934: this should involve an explicit interface
# soon, rather than one being implicitly created with only the provided Param.
for union_member in union_membership.get(the_get.input_type):
add_get_edge(the_get.output_type, union_member)
native_engine.tasks_add_union(tasks, the_get.output_type, (union_member,))
else:
# Otherwise, the Get subject is a "concrete" type, so add a single Get edge.
add_get_edge(the_get.output_type, the_get.input_type)
native_engine.tasks_add_get(tasks, the_get.output_type, the_get.input_type)

native_engine.tasks_task_end(tasks)

Expand Down
20 changes: 20 additions & 0 deletions src/rust/engine/src/externs/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,11 @@ py_module_initializer!(native_engine, |py, m| {
"tasks_add_get",
py_fn!(py, tasks_add_get(a: PyTasks, b: PyType, c: PyType)),
)?;
m.add(
py,
"tasks_add_union",
py_fn!(py, tasks_add_union(a: PyTasks, b: PyType, c: Vec<PyType>)),
)?;
m.add(
py,
"tasks_add_select",
Expand Down Expand Up @@ -1294,6 +1299,21 @@ fn tasks_add_get(py: Python, tasks_ptr: PyTasks, output: PyType, input: PyType)
})
}

fn tasks_add_union(
py: Python,
tasks_ptr: PyTasks,
output_type: PyType,
input_types: Vec<PyType>,
) -> PyUnitResult {
with_tasks(py, tasks_ptr, |tasks| {
tasks.add_union(
externs::type_for(output_type),
input_types.into_iter().map(externs::type_for).collect(),
);
Ok(None)
})
}

fn tasks_add_select(py: Python, tasks_ptr: PyTasks, selector: PyType) -> PyUnitResult {
with_tasks(py, tasks_ptr, |tasks| {
let selector = externs::type_for(selector);
Expand Down
64 changes: 40 additions & 24 deletions src/rust/engine/src/nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1076,23 +1076,45 @@ impl Task {
let context = context.clone();
let mut params = params.clone();
let entry = entry.clone();
let dependency_key = selectors::DependencyKey::JustGet(selectors::Get {
output: get.output,
input: *get.input.type_id(),
});
let entry_res = context
.core
.rule_graph
.edges_for_inner(&entry)
.ok_or_else(|| throw(&format!("No edges for task {:?} exist!", entry)))
.and_then(|edges| {
edges.entry_for(&dependency_key).cloned().ok_or_else(|| {
async move {
let dependency_key = selectors::DependencyKey::JustGet(selectors::Get {
output: get.output,
input: *get.input.type_id(),
});
params.put(get.input);

let edges = context
.core
.rule_graph
.edges_for_inner(&entry)
.ok_or_else(|| throw(&format!("No edges for task {:?} exist!", entry)))?;

// See if there is a Get: otherwise, a union (which is executed as a Query).
// See #12934 for further cleanup of this API.
let select = edges
.entry_for(&dependency_key)
.cloned()
.map(|entry| {
// The subject of the get is a new parameter that replaces an existing param of the same
// type.
Select::new(params.clone(), get.output, entry)
})
.or_else(|| {
// Is a union.
let (_, rule_edges) = context
.core
.rule_graph
.find_root(vec![*get.input.type_id()], get.output)
.ok()?;
Some(Select::new_from_edges(params, get.output, &rule_edges))
})
.ok_or_else(|| {
if externs::is_union(get.input_type) {
throw(&format!(
"Invalid Get. Because the second argument to `Get({}, {}, {:?})` is annotated \
with `@union`, the third argument should be a member of that union. Did you \
intend to register `UnionRule({}, {})`? If not, you may be using the wrong \
type ({}) for the third argument.",
with `@union`, the third argument should be a member of that union. Did you \
intend to register `UnionRule({}, {})`? If not, you may be using the wrong \
type ({}) for the third argument.",
get.output,
get.input_type,
get.input,
Expand All @@ -1105,19 +1127,13 @@ impl Task {
// `type(input) != input_type`.
throw(&format!(
"Get({}, {}, {}) was not detected in your @rule body at rule compile time. \
Was the `Get` constructor called in a separate function, or perhaps \
dynamically? If so, it must be inlined into the @rule body.",
Was the `Get` constructor called in a separate function, or perhaps \
dynamically? If so, it must be inlined into the @rule body.",
get.output, get.input_type, get.input
))
}
})
});
// The subject of the get is a new parameter that replaces an existing param of the same
// type.
params.put(get.input);
match entry_res {
Ok(entry) => Select::new(params, get.output, entry).run(context).boxed(),
Err(e) => future::err(e).boxed(),
})?;
select.run(context).await
}
})
.collect::<Vec<_>>();
Expand Down
15 changes: 15 additions & 0 deletions src/rust/engine/src/tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ pub struct Task {
pub engine_aware_return_type: bool,
pub clause: Vec<TypeId>,
pub gets: Vec<Get>,
// TODO: This is a preliminary implementation of #12934: we should overhaul naming to
// align Query and @union/Protocol as described there.
pub unions: Vec<Query<Rule>>,
pub func: Function,
pub cacheable: bool,
pub display_info: DisplayInfo,
Expand Down Expand Up @@ -234,6 +237,7 @@ impl Tasks {
engine_aware_return_type,
clause: Vec::new(),
gets: Vec::new(),
unions: Vec::new(),
func,
display_info: DisplayInfo { name, desc, level },
});
Expand All @@ -248,6 +252,17 @@ impl Tasks {
.push(Get { output, input });
}

pub fn add_union(&mut self, product: TypeId, params: Vec<TypeId>) {
let query = Query::new(product, params);
self.queries.insert(query.clone());
self
.preparing
.as_mut()
.expect("Must `begin()` a task creation before adding unions!")
.unions
.push(query);
}

pub fn add_select(&mut self, selector: TypeId) {
self
.preparing
Expand Down

0 comments on commit 26fc055

Please sign in to comment.