Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Serialised extensions #1371

Merged
merged 19 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,16 @@ pub enum SignatureError {
cached: Signature,
expected: Signature,
},

/// Extension declaration specifies a binary compute signature function, but none
/// was loaded.
#[error("Binary compute signature function not loaded.")]
MissingComputeFunc,

/// Extension declaration specifies a binary compute signature function, but none
/// was loaded.
#[error("Binary validate signature function not loaded.")]
MissingValidateFunc,
}

/// Concrete instantiations of types and operations defined in extensions.
Expand Down
8 changes: 6 additions & 2 deletions hugr-core/src/extension/declarative/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,12 @@ enum TypeDefBoundDeclaration {
impl From<TypeDefBoundDeclaration> for TypeDefBound {
fn from(bound: TypeDefBoundDeclaration) -> Self {
match bound {
TypeDefBoundDeclaration::Copyable => Self::Explicit(TypeBound::Copyable),
TypeDefBoundDeclaration::Any => Self::Explicit(TypeBound::Any),
TypeDefBoundDeclaration::Copyable => Self::Explicit {
bound: TypeBound::Copyable,
},
TypeDefBoundDeclaration::Any => Self::Explicit {
bound: TypeBound::Any,
},
}
}
}
Expand Down
107 changes: 62 additions & 45 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::ops::{OpName, OpNameRef};
use crate::types::type_param::{check_type_args, TypeArg, TypeParam};
use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature};
use crate::Hugr;
mod serialize_signature_func;

/// Trait necessary for binary computations of OpDef signature
pub trait CustomSignatureFunc: Send + Sync {
Expand Down Expand Up @@ -114,29 +115,19 @@ pub trait CustomLowerFunc: Send + Sync {
) -> Option<Hugr>;
}

/// Encode a signature as `PolyFuncTypeRV` but optionally allow validating type
/// Encode a signature as `PolyFuncTypeRV` but allow validating type
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
/// arguments via a custom binary. The binary cannot be serialized so will be
/// lost over a serialization round-trip.
#[derive(serde::Deserialize, serde::Serialize)]
pub struct CustomValidator {
#[serde(flatten)]
poly_func: PolyFuncTypeRV,
#[serde(skip)]
/// Custom function for validating type arguments before returning the signature.
pub(crate) validate: Box<dyn ValidateTypeArgs>,
}

impl CustomValidator {
/// Encode a signature using a `PolyFuncTypeRV`
pub fn from_polyfunc(poly_func: impl Into<PolyFuncTypeRV>) -> Self {
Self {
poly_func: poly_func.into(),
validate: Default::default(),
}
}

/// Encode a signature using a `PolyFuncTypeRV`, with a custom function for
/// validating type arguments before returning the signature.
pub fn new_with_validator(
pub fn new(
poly_func: impl Into<PolyFuncTypeRV>,
validate: impl ValidateTypeArgs + 'static,
) -> Self {
Expand All @@ -147,21 +138,22 @@ impl CustomValidator {
}
}

/// The two ways in which an OpDef may compute the Signature of each operation node.
#[derive(serde::Deserialize, serde::Serialize)]
/// The ways in which an OpDef may compute the Signature of each operation node.
pub enum SignatureFunc {
// Note: except for serialization, we could have type schemes just implement the same
// CustomSignatureFunc trait too, and replace this enum with Box<dyn CustomSignatureFunc>.
// However instead we treat all CustomFunc's as non-serializable.
/// A PolyFuncType (polymorphic function type), with optional custom
/// validation for provided type arguments,
#[serde(rename = "signature")]
PolyFuncType(CustomValidator),
#[serde(skip)]
/// An explicit polymorphic function type.
PolyFuncType(PolyFuncTypeRV),
/// A polymorphic function type with a custom binary for validating type arguments.
Copy link
Contributor

@acl-cqc acl-cqc Jul 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we could slightly emphasize that this is mostly the same as the previous case but PLUS a custom binary

CustomValidator(CustomValidator),
/// Serialized declaration specified a custom validate binary but it was not provided.
MissingValidateFunc(PolyFuncTypeRV),
/// A custom binary which computes a polymorphic function type given values
/// for its static type parameters.
CustomFunc(Box<dyn CustomSignatureFunc>),
/// Serialized declaration specified a custom compute binary but it was not provided.
MissingComputeFunc,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about a missing custom validation func? In either case I think we just have to trust the cache, maybe with a warning, so I guess if the warning is the same then we don't have to distinguish, is that the plan?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure - in the case of missing validation function do you trust the cached signature or use the type scheme to generate the signature and check against cache as you would without custom validation?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Given the validation function can only say one of two things - invalid, or use the type scheme - you can try the latter (which might reject if TypeArgs don't match the TypeParams, say) and see if that matches; that might get you an error, but even if the typescheme says ok, if there's a binary validation function that you haven't got, then that still has to be a warning

}

#[derive(PartialEq, Eq, Debug)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Methinks we should not need NoValidate. One should use Signature::PolyFuncType instead.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, moved it just to the test that used it

struct NoValidate;
impl ValidateTypeArgs for NoValidate {
fn validate<'o, 'a: 'o>(
Expand All @@ -188,39 +180,50 @@ impl<T: CustomSignatureFunc + 'static> From<T> for SignatureFunc {

impl From<PolyFuncType> for SignatureFunc {
fn from(value: PolyFuncType) -> Self {
Self::PolyFuncType(CustomValidator::from_polyfunc(value))
Self::PolyFuncType(value.into())
}
}

impl From<PolyFuncTypeRV> for SignatureFunc {
fn from(v: PolyFuncTypeRV) -> Self {
Self::PolyFuncType(CustomValidator::from_polyfunc(v))
Self::PolyFuncType(v)
}
}

impl From<FuncValueType> for SignatureFunc {
fn from(v: FuncValueType) -> Self {
Self::PolyFuncType(CustomValidator::from_polyfunc(v))
Self::PolyFuncType(v.into())
}
}

impl From<Signature> for SignatureFunc {
fn from(v: Signature) -> Self {
Self::PolyFuncType(CustomValidator::from_polyfunc(FuncValueType::from(v)))
Self::PolyFuncType(FuncValueType::from(v).into())
}
}

impl From<CustomValidator> for SignatureFunc {
fn from(v: CustomValidator) -> Self {
Self::PolyFuncType(v)
Self::CustomValidator(v)
}
}

impl SignatureFunc {
fn static_params(&self) -> &[TypeParam] {
match self {
SignatureFunc::PolyFuncType(ts) => ts.poly_func.params(),
fn static_params(&self) -> Result<&[TypeParam], SignatureError> {
Ok(match self {
SignatureFunc::PolyFuncType(ts)
| SignatureFunc::CustomValidator(CustomValidator { poly_func: ts, .. })
| SignatureFunc::MissingValidateFunc(ts) => ts.params(),
SignatureFunc::CustomFunc(func) => func.static_params(),
SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
acl-cqc marked this conversation as resolved.
Show resolved Hide resolved
})
}

/// If the signature is missing a custom validation function, ignore and treat as
/// self-contained type scheme (with no custom validation).
pub fn ignore_missing_validation(&mut self) {
if let SignatureFunc::MissingValidateFunc(ts) = self {
*self = SignatureFunc::PolyFuncType(ts.clone());
}
}

Expand All @@ -243,10 +246,11 @@ impl SignatureFunc {
) -> Result<Signature, SignatureError> {
let temp: PolyFuncTypeRV; // to keep alive
let (pf, args) = match &self {
SignatureFunc::PolyFuncType(custom) => {
custom.validate.validate(args, def, exts)?;
SignatureFunc::CustomValidator(custom) => {
custom.validate.as_ref().validate(args, def, exts)?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

presumably this extra as_ref() is needed? Code looks very similar to what we had before.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wasn't needed, removed

(&custom.poly_func, args)
}
SignatureFunc::PolyFuncType(ts) => (ts, args),
SignatureFunc::CustomFunc(func) => {
let static_params = func.static_params();
let (static_args, other_args) = args.split_at(min(static_params.len(), args.len()));
Expand All @@ -255,6 +259,10 @@ impl SignatureFunc {
temp = func.compute_signature(static_args, def, exts)?;
(&temp, other_args)
}
SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
SignatureFunc::MissingValidateFunc(_) => {
return Err(SignatureError::MissingValidateFunc)
}
};

let mut res = pf.instantiate(args, exts)?;
Expand All @@ -268,8 +276,11 @@ impl SignatureFunc {
impl Debug for SignatureFunc {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::PolyFuncType(ts) => ts.poly_func.fmt(f),
Self::CustomValidator(ts) => ts.poly_func.fmt(f),
Self::PolyFuncType(ts) => ts.fmt(f),
Self::CustomFunc { .. } => f.write_str("<custom sig>"),
Self::MissingComputeFunc => f.write_str("<missing custom sig>"),
Self::MissingValidateFunc(_) => f.write_str("<missing custom validation>"),
}
}
}
Expand Down Expand Up @@ -321,10 +332,11 @@ pub struct OpDef {
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
misc: HashMap<String, serde_json::Value>,

#[serde(flatten)]
#[serde(with = "serialize_signature_func", flatten)]
signature_func: SignatureFunc,
// Some operations cannot lower themselves and tools that do not understand them
// can only treat them as opaque/black-box ops.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub(crate) lower_funcs: Vec<LowerFunc>,

/// Operations can optionally implement [`ConstFold`] to implement constant folding.
Expand All @@ -344,7 +356,8 @@ impl OpDef {
) -> Result<(), SignatureError> {
let temp: PolyFuncTypeRV; // to keep alive
let (pf, args) = match &self.signature_func {
SignatureFunc::PolyFuncType(ts) => (&ts.poly_func, args),
SignatureFunc::CustomValidator(ts) => (&ts.poly_func, args),
SignatureFunc::PolyFuncType(ts) => (ts, args),
SignatureFunc::CustomFunc(custom) => {
let (static_args, other_args) =
args.split_at(min(custom.static_params().len(), args.len()));
Expand All @@ -355,6 +368,10 @@ impl OpDef {
temp = custom.compute_signature(static_args, self, exts)?;
(&temp, other_args)
}
SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
SignatureFunc::MissingValidateFunc(_) => {
return Err(SignatureError::MissingValidateFunc)
}
};
args.iter()
.try_for_each(|ta| ta.validate(exts, var_decls))?;
Expand Down Expand Up @@ -409,14 +426,14 @@ impl OpDef {
}

/// Returns a reference to the params of this [`OpDef`].
pub fn params(&self) -> &[TypeParam] {
pub fn params(&self) -> Result<&[TypeParam], SignatureError> {
self.signature_func.static_params()
}

pub(super) fn validate(&self, exts: &ExtensionRegistry) -> Result<(), SignatureError> {
// TODO https://github.com/CQCL/hugr/issues/624 validate declared TypeParams
// for both type scheme and custom binary
if let SignatureFunc::PolyFuncType(ts) = &self.signature_func {
if let SignatureFunc::CustomValidator(ts) = &self.signature_func {
// The type scheme may contain row variables so be of variable length;
// these will have to be substituted to fixed-length concrete types when
// the OpDef is instantiated into an actual OpType.
Expand Down Expand Up @@ -557,12 +574,14 @@ pub(super) mod test {
// a compile error here. To fix: modify the fields matched on here,
// maintaining the lack of `..` and, for each part that is
// serializable, ensure we are checking it for equality below.
SignatureFunc::PolyFuncType(CustomValidator {
SignatureFunc::CustomValidator(CustomValidator {
poly_func,
validate: _,
}) => Some(poly_func.clone()),
})
| SignatureFunc::PolyFuncType(poly_func)
| SignatureFunc::MissingValidateFunc(poly_func) => Some(poly_func.clone()),
// This is ruled out by `new()` but leave it here for later.
SignatureFunc::CustomFunc(_) => None,
SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect the comment on the previous line can go (I suspect "later" meant "when we have implemented serialization", essentially!), please clarify if I'm wrong

};

let get_lower_funcs = |lfs: &Vec<LowerFunc>| {
Expand Down Expand Up @@ -787,9 +806,7 @@ pub(super) mod test {

use crate::{
builder::test::simple_dfg_hugr,
extension::{
op_def::LowerFunc, CustomValidator, ExtensionId, ExtensionSet, OpDef, SignatureFunc,
},
extension::{op_def::LowerFunc, ExtensionId, ExtensionSet, OpDef, SignatureFunc},
types::PolyFuncTypeRV,
};

Expand All @@ -801,7 +818,7 @@ pub(super) mod test {
// this is not serialized. When it is, we should generate
// examples here .
any::<PolyFuncTypeRV>()
.prop_map(|x| SignatureFunc::PolyFuncType(CustomValidator::from_polyfunc(x)))
.prop_map(SignatureFunc::PolyFuncType)
.boxed()
}
}
Expand Down
Loading
Loading