Skip to content

Commit

Permalink
feat!: disallow opaque ops during validation (#1431)
Browse files Browse the repository at this point in the history
Update python to pass extensions as package when validating

Use cached signatures from opaque when resolving if missing binary
compute.
Closes #1362


BREAKING CHANGE: HUGRs containing opaque operations that don't point to
an extension in the registry will fail to validate. Use `Package` to
pack extensions with HUGRs for serialisation.

---------

Co-authored-by: Alec Edgington <54802828+cqc-alec@users.noreply.github.com>
  • Loading branch information
ss2165 and cqc-alec authored Aug 15, 2024
1 parent 8e8bba5 commit fbbb805
Show file tree
Hide file tree
Showing 14 changed files with 207 additions and 76 deletions.
16 changes: 6 additions & 10 deletions hugr-core/src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,13 @@ mod test {
use crate::utils::test_quantum_extension::{
self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64,
};
use crate::Extension;
use crate::{
builder::{
test::{build_main, NAT, QB},
DataflowSubContainer,
},
extension::prelude::BOOL_T,
ops::custom::OpaqueOp,
type_row,
types::Signature,
};
Expand Down Expand Up @@ -296,19 +296,15 @@ mod test {

#[test]
fn with_nonlinear_and_outputs() {
let missing_ext: ExtensionId = "MissingExt".try_into().unwrap();
let my_custom_op = OpaqueOp::new(
missing_ext.clone(),
"MyOp",
"unknown op".to_string(),
vec![],
Signature::new(vec![QB, NAT], vec![QB]),
);
let my_ext_name: ExtensionId = "MyExt".try_into().unwrap();
let mut my_ext = Extension::new_test(my_ext_name.clone());
let my_custom_op = my_ext.simple_ext_op("MyOp", Signature::new(vec![QB, NAT], vec![QB]));

let build_res = build_main(
Signature::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T])
.with_extension_delta(ExtensionSet::from_iter([
test_quantum_extension::EXTENSION_ID,
missing_ext,
my_ext_name,
]))
.into(),
|mut f_build| {
Expand Down
12 changes: 12 additions & 0 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,18 @@ pub mod test {
pub(crate) fn new_test(name: ExtensionId) -> Self {
Self::new(name, Version::new(0, 0, 0))
}

/// Add a simple OpDef to the extension and return an extension op for it.
/// No description, no type parameters.
pub(crate) fn simple_ext_op(
&mut self,
name: &str,
signature: impl Into<SignatureFunc>,
) -> ExtensionOp {
self.add_op(name.into(), "".to_string(), signature).unwrap();
self.instantiate_extension_op(name, [], &PRELUDE_REGISTRY)
.unwrap()
}
}

#[test]
Expand Down
5 changes: 2 additions & 3 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,8 @@ impl SignatureFunc {
(&temp, other_args)
}
SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
SignatureFunc::MissingValidateFunc(_) => {
return Err(SignatureError::MissingValidateFunc)
}
// TODO raise warning: https://github.com/CQCL/hugr/issues/1432
SignatureFunc::MissingValidateFunc(ts) => (ts, args),
};

let mut res = pf.instantiate(args, exts)?;
Expand Down
15 changes: 8 additions & 7 deletions hugr-core/src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -450,11 +450,11 @@ mod test {
DataflowSubContainer, HugrBuilder, SubContainer,
};
use crate::extension::prelude::{BOOL_T, USIZE_T};
use crate::extension::{ExtensionId, ExtensionRegistry, PRELUDE, PRELUDE_REGISTRY};
use crate::extension::{ExtensionRegistry, PRELUDE, PRELUDE_REGISTRY};
use crate::hugr::internal::HugrMutInternals;
use crate::hugr::rewrite::replace::WhichHugr;
use crate::hugr::{HugrMut, Rewrite};
use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::ops::custom::ExtensionOp;
use crate::ops::dataflow::DataflowOpTrait;
use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle};
use crate::ops::{self, Case, DataflowBlock, OpTag, OpType, DFG};
Expand Down Expand Up @@ -639,12 +639,13 @@ mod test {

#[test]
fn test_invalid() {
let unknown_ext: ExtensionId = "unknown_ext".try_into().unwrap();
let mut new_ext = crate::Extension::new_test("new_ext".try_into().unwrap());
let ext_name = new_ext.name().clone();
let utou = Signature::new_endo(vec![USIZE_T]);
let mk_op = |s| OpaqueOp::new(unknown_ext.clone(), s, String::new(), vec![], utou.clone());
let mut mk_op = |s| new_ext.simple_ext_op(s, utou.clone());
let mut h = DFGBuilder::new(
Signature::new(type_row![USIZE_T, BOOL_T], type_row![USIZE_T])
.with_extension_delta(unknown_ext.clone()),
.with_extension_delta(ext_name.clone()),
)
.unwrap();
let [i, b] = h.input_wires_arr();
Expand All @@ -653,7 +654,7 @@ mod test {
(vec![type_row![]; 2], b),
[(USIZE_T, i)],
type_row![USIZE_T],
unknown_ext.clone(),
ext_name.clone(),
)
.unwrap();
let mut case1 = cond.case_builder(0).unwrap();
Expand All @@ -667,7 +668,7 @@ mod test {
.unwrap();
let mut baz_dfg = case2
.dfg_builder(
utou.clone().with_extension_delta(unknown_ext.clone()),
utou.clone().with_extension_delta(ext_name.clone()),
bar.outputs(),
)
.unwrap();
Expand Down
34 changes: 31 additions & 3 deletions hugr-core/src/hugr/serialize/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ use crate::extension::prelude::{BOOL_T, PRELUDE_ID, QB_T, USIZE_T};
use crate::extension::simple_op::MakeRegisteredOp;
use crate::extension::{test::SimpleOpDef, ExtensionSet, EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::internal::HugrMutInternals;
use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::hugr::validate::ValidationError;
use crate::ops::custom::{ExtensionOp, OpaqueOp, OpaqueOpError};
use crate::ops::{self, dataflow::IOTrait, Input, Module, Noop, Output, Value, DFG};
use crate::std_extensions::arithmetic::float_types::FLOAT64_TYPE;
use crate::std_extensions::arithmetic::int_ops::INT_OPS_REGISTRY;
Expand Down Expand Up @@ -338,6 +339,25 @@ fn dfg_roundtrip() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

#[test]
fn extension_ops() -> Result<(), Box<dyn std::error::Error>> {
let tp: Vec<Type> = vec![BOOL_T; 1];
let mut dfg = DFGBuilder::new(endo_sig(tp))?;
let [wire] = dfg.input_wires_arr();

// Add an extension operation
let extension_op: ExtensionOp = NotOp.to_extension_op().unwrap();
let wire = dfg
.add_dataflow_op(extension_op.clone(), [wire])
.unwrap()
.out_wire(0);

let hugr = dfg.finish_hugr_with_outputs([wire], &PRELUDE_REGISTRY)?;

check_hugr_roundtrip(&hugr, true);
Ok(())
}

#[test]
fn opaque_ops() -> Result<(), Box<dyn std::error::Error>> {
let tp: Vec<Type> = vec![BOOL_T; 1];
Expand All @@ -353,11 +373,19 @@ fn opaque_ops() -> Result<(), Box<dyn std::error::Error>> {

// Add an unresolved opaque operation
let opaque_op: OpaqueOp = extension_op.into();
let ext_name = opaque_op.extension().clone();
let wire = dfg.add_dataflow_op(opaque_op, [wire]).unwrap().out_wire(0);

let hugr = dfg.finish_hugr_with_outputs([wire], &PRELUDE_REGISTRY)?;
assert_eq!(
dfg.finish_hugr_with_outputs([wire], &PRELUDE_REGISTRY),
Err(ValidationError::OpaqueOpError(OpaqueOpError::UnresolvedOp(
wire.node(),
"Not".into(),
ext_name
))
.into())
);

check_hugr_roundtrip(&hugr, true);
Ok(())
}

Expand Down
25 changes: 13 additions & 12 deletions hugr-core/src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use thiserror::Error;
use crate::extension::{ExtensionRegistry, SignatureError, TO_BE_INFERRED};

use crate::ops::constant::ConstTypeError;
use crate::ops::custom::{resolve_opaque_op, ExtensionOp, OpaqueOpError};
use crate::ops::custom::{ExtensionOp, OpaqueOpError};
use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError};
use crate::ops::{FuncDefn, OpParent, OpTag, OpTrait, OpType, ValidateOp};
use crate::types::type_param::TypeParam;
Expand Down Expand Up @@ -158,6 +158,13 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
fn validate_node(&self, node: Node) -> Result<(), ValidationError> {
let op_type = self.hugr.get_optype(node);

if let OpType::OpaqueOp(opaque) = op_type {
Err(OpaqueOpError::UnresolvedOp(
node,
opaque.op_name().clone(),
opaque.extension().clone(),
))?;
}
// The Hugr can have only one root node.
if node == self.hugr.root() {
// The root node has no edges.
Expand Down Expand Up @@ -577,17 +584,11 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
match op_type {
OpType::ExtensionOp(ext_op) => validate_ext(ext_op)?,
OpType::OpaqueOp(opaque) => {
// Try to resolve serialized names to actual OpDefs in Extensions.
if let Some(ext_op) = resolve_opaque_op(node, opaque, self.extension_registry)? {
validate_ext(&ext_op)?;
} else {
// Best effort. Just check TypeArgs are valid in themselves, allowing any of them
// to contain type vars (we don't know how many are binary params, so accept if in doubt)
for arg in opaque.args() {
arg.validate(self.extension_registry, var_decls)
.map_err(|cause| ValidationError::SignatureError { node, cause })?;
}
}
Err(OpaqueOpError::UnresolvedOp(
node,
opaque.op_name().clone(),
opaque.extension().clone(),
))?;
}
OpType::Call(c) => {
c.validate(self.extension_registry)
Expand Down
115 changes: 101 additions & 14 deletions hugr-core/src/ops/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,31 @@ impl ExtensionOp {
})
}

/// If OpDef is missing binary computation, trust the cached signature.
fn new_with_cached(
def: Arc<OpDef>,
args: impl Into<Vec<TypeArg>>,
opaque: &OpaqueOp,
exts: &ExtensionRegistry,
) -> Result<Self, SignatureError> {
let args = args.into();
// TODO skip computation depending on config
// see https://github.com/CQCL/hugr/issues/1363
let signature = match def.compute_signature(&args, exts) {
Ok(sig) => sig,
Err(SignatureError::MissingComputeFunc) => {
// TODO raise warning: https://github.com/CQCL/hugr/issues/1432
opaque.signature()
}
Err(e) => return Err(e),
};
Ok(Self {
def,
args,
signature,
})
}

/// Return the argument values for this operation.
pub fn args(&self) -> &[TypeArg] {
&self.args
Expand Down Expand Up @@ -224,9 +249,8 @@ pub fn resolve_extension_ops(
let mut replacements = Vec::new();
for n in h.nodes() {
if let OpType::OpaqueOp(opaque) = h.get_optype(n) {
if let Some(resolved) = resolve_opaque_op(n, opaque, extension_registry)? {
replacements.push((n, resolved))
}
let resolved = resolve_opaque_op(n, opaque, extension_registry)?;
replacements.push((n, resolved));
}
}
// Only now can we perform the replacements as the 'for' loop was borrowing 'h' preventing use from using it mutably
Expand All @@ -252,7 +276,7 @@ pub fn resolve_opaque_op(
node: Node,
opaque: &OpaqueOp,
extension_registry: &ExtensionRegistry,
) -> Result<Option<ExtensionOp>, OpaqueOpError> {
) -> Result<ExtensionOp, OpaqueOpError> {
if let Some(r) = extension_registry.get(&opaque.extension) {
// Fail if the Extension was found but did not have the expected operation
let Some(def) = r.get_op(&opaque.name) else {
Expand All @@ -262,12 +286,17 @@ pub fn resolve_opaque_op(
r.name().clone(),
));
};
let ext_op = ExtensionOp::new(def.clone(), opaque.args.clone(), extension_registry)
.map_err(|e| OpaqueOpError::SignatureError {
node,
name: opaque.name.clone(),
cause: e,
})?;
let ext_op = ExtensionOp::new_with_cached(
def.clone(),
opaque.args.clone(),
opaque,
extension_registry,
)
.map_err(|e| OpaqueOpError::SignatureError {
node,
name: opaque.name.clone(),
cause: e,
})?;
if opaque.signature() != ext_op.signature() {
return Err(OpaqueOpError::SignatureMismatch {
node,
Expand All @@ -277,9 +306,13 @@ pub fn resolve_opaque_op(
stored: opaque.signature.clone(),
});
};
Ok(Some(ext_op))
Ok(ext_op)
} else {
Ok(None)
Err(OpaqueOpError::UnresolvedOp(
node,
opaque.name.clone(),
opaque.extension.clone(),
))
}
}

Expand Down Expand Up @@ -310,17 +343,25 @@ pub enum OpaqueOpError {
#[source]
cause: SignatureError,
},
/// Unresolved operation encountered during validation.
#[error("Unexpected unresolved opaque operation '{1}' in {0}, from Extension {2}.")]
UnresolvedOp(Node, OpName, ExtensionId),
}

#[cfg(test)]
mod test {

use crate::{
extension::prelude::{BOOL_T, QB_T, USIZE_T},
extension::{
prelude::{BOOL_T, QB_T, USIZE_T},
SignatureFunc,
},
std_extensions::arithmetic::{
int_ops::{self, INT_OPS_REGISTRY},
int_types::INT_TYPES,
},
types::FuncValueType,
Extension,
};

use super::*;
Expand Down Expand Up @@ -357,8 +398,54 @@ mod test {
);
let resolved =
super::resolve_opaque_op(Node::from(portgraph::NodeIndex::new(1)), &opaque, registry)
.unwrap()
.unwrap();
assert_eq!(resolved.def().name(), "itobool");
}

#[test]
fn resolve_missing() {
let mut ext = Extension::new_test("ext".try_into().unwrap());
let ext_id = ext.name().clone();
let val_name = "missing_val";
let comp_name = "missing_comp";

let endo_sig = Signature::new_endo(BOOL_T);
ext.add_op(
val_name.into(),
"".to_string(),
SignatureFunc::MissingValidateFunc(FuncValueType::from(endo_sig.clone()).into()),
)
.unwrap();

ext.add_op(
comp_name.into(),
"".to_string(),
SignatureFunc::MissingComputeFunc,
)
.unwrap();
let registry = ExtensionRegistry::try_new([ext]).unwrap();
let opaque_val = OpaqueOp::new(
ext_id.clone(),
val_name,
"".into(),
vec![],
endo_sig.clone(),
);
let opaque_comp = OpaqueOp::new(ext_id.clone(), comp_name, "".into(), vec![], endo_sig);
let resolved_val = super::resolve_opaque_op(
Node::from(portgraph::NodeIndex::new(1)),
&opaque_val,
&registry,
)
.unwrap();
assert_eq!(resolved_val.def().name(), val_name);

let resolved_comp = super::resolve_opaque_op(
Node::from(portgraph::NodeIndex::new(2)),
&opaque_comp,
&registry,
)
.unwrap();
assert_eq!(resolved_comp.def().name(), comp_name);
}
}
Loading

0 comments on commit fbbb805

Please sign in to comment.