Skip to content

Commit

Permalink
feat: extract serialisation in to module and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Jul 29, 2024
1 parent 0353aa8 commit 6573761
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 59 deletions.
60 changes: 1 addition & 59 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::ops::{OpName, OpNameRef};
use crate::types::type_param::{check_type_args, TypeArg, TypeParam};
use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature};
use crate::Hugr;
mod serialize_signature_func;

/// Trait necessary for binary computations of OpDef signature
pub trait CustomSignatureFunc: Send + Sync {
Expand Down Expand Up @@ -163,65 +164,6 @@ pub enum SignatureFunc {
MissingComputeFunc,
}

mod serialize_signature_func {
use serde::{Deserialize, Serialize};

use super::{PolyFuncTypeRV, SignatureFunc};
#[derive(serde::Deserialize, serde::Serialize)]
struct SerSignatureFunc {
/// If the type scheme is available explicitly, store it.
signature: Option<PolyFuncTypeRV>,
/// Whether an associated binary function is expected.
/// If `signature` is `None`, a true value here indicates a custom compute function.
/// If `signature` is not `None`, a true value here indicates a custom validation function.
binary: bool,
}

pub(super) fn serialize<S>(
value: &super::SignatureFunc,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match value {
SignatureFunc::PolyFuncType(custom) => SerSignatureFunc {
signature: Some(custom.poly_func.clone()),
binary: custom.validate.is_some(),
},
SignatureFunc::MissingValidateFunc(poly_func) => SerSignatureFunc {
signature: Some(poly_func.clone()),
binary: true,
},
SignatureFunc::CustomFunc(_) => SerSignatureFunc {
signature: None,
binary: true,
},
SignatureFunc::MissingComputeFunc => SerSignatureFunc {
signature: None,
binary: false,
},
}
.serialize(serializer)
}

pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<super::SignatureFunc, D::Error>
where
D: serde::Deserializer<'de>,
{
let SerSignatureFunc { signature, binary } = SerSignatureFunc::deserialize(deserializer)?;

match (signature, binary) {
(Some(sig), false) => Ok(sig.into()),
(Some(sig), true) => Ok(SignatureFunc::MissingValidateFunc(sig)),
(None, true) => Ok(SignatureFunc::MissingComputeFunc),
(None, false) => Err(serde::de::Error::custom(
"No signature provided and custom computation not expected.",
)),
}
}
}

#[derive(PartialEq, Eq, Debug)]
struct NoValidate;
impl ValidateTypeArgs for NoValidate {
Expand Down
178 changes: 178 additions & 0 deletions hugr-core/src/extension/op_def/serialize_signature_func.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
use serde::{Deserialize, Serialize};

use super::{PolyFuncTypeRV, SignatureFunc};
#[derive(serde::Deserialize, serde::Serialize, PartialEq, Debug, Clone)]
struct SerSignatureFunc {
/// If the type scheme is available explicitly, store it.
signature: Option<PolyFuncTypeRV>,
/// Whether an associated binary function is expected.
/// If `signature` is `None`, a true value here indicates a custom compute function.
/// If `signature` is not `None`, a true value here indicates a custom validation function.
binary: bool,
}

pub(super) fn serialize<S>(value: &super::SignatureFunc, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match value {
SignatureFunc::PolyFuncType(custom) => SerSignatureFunc {
signature: Some(custom.poly_func.clone()),
binary: custom.validate.is_some(),
},
SignatureFunc::MissingValidateFunc(poly_func) => SerSignatureFunc {
signature: Some(poly_func.clone()),
binary: true,
},
SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => SerSignatureFunc {
signature: None,
binary: true,
},
}
.serialize(serializer)
}

pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<super::SignatureFunc, D::Error>
where
D: serde::Deserializer<'de>,
{
let SerSignatureFunc { signature, binary } = SerSignatureFunc::deserialize(deserializer)?;

match (signature, binary) {
(Some(sig), false) => Ok(sig.into()),
(Some(sig), true) => Ok(SignatureFunc::MissingValidateFunc(sig)),
(None, true) => Ok(SignatureFunc::MissingComputeFunc),
(None, false) => Err(serde::de::Error::custom(
"No signature provided and custom computation not expected.",
)),
}
}
#[derive(serde::Deserialize, serde::Serialize, Debug)]
/// Wrapper we can derive serde for, to allow round-trip serialization
struct Wrapper {
#[serde(
serialize_with = "serialize",
deserialize_with = "deserialize",
flatten
)]
inner: SignatureFunc,
}

#[cfg(test)]
mod test {
use cool_asserts::assert_matches;
use serde::de::Error;

use super::*;
use crate::{
extension::{op_def::NoValidate, prelude::USIZE_T, CustomSignatureFunc, CustomValidator},
types::{FuncValueType, Signature, TypeArg},
};
// Define test-only conversions via serialization roundtrip
impl TryFrom<SerSignatureFunc> for SignatureFunc {
type Error = serde_json::Error;
fn try_from(value: SerSignatureFunc) -> Result<Self, Self::Error> {
let ser = serde_json::to_value(value).unwrap();
let w: Wrapper = serde_json::from_value(ser)?;
Ok(w.inner)
}
}

impl From<SignatureFunc> for SerSignatureFunc {
fn from(value: SignatureFunc) -> Self {
let ser = serde_json::to_value(Wrapper { inner: value }).unwrap();
serde_json::from_value(ser).unwrap()
}
}
struct CustomSig;

impl CustomSignatureFunc for CustomSig {
fn compute_signature<'o, 'a: 'o>(
&'a self,
_arg_values: &[TypeArg],
_def: &'o crate::extension::op_def::OpDef,
_extension_registry: &crate::extension::ExtensionRegistry,
) -> Result<crate::types::PolyFuncTypeRV, crate::extension::SignatureError> {
Ok(Default::default())
}

fn static_params(&self) -> &[crate::types::type_param::TypeParam] {
&[]
}
}
#[test]
fn test_serial_sig_func() {
// test round-trip
let sig: FuncValueType = Signature::new_endo(USIZE_T.clone()).into();
let simple: SignatureFunc = sig.clone().into();
let ser: SerSignatureFunc = simple.into();
let expected_ser = SerSignatureFunc {
signature: Some(sig.clone().into()),
binary: false,
};

assert_eq!(ser, expected_ser);
let deser = SignatureFunc::try_from(ser).unwrap();
assert_matches!( deser,
SignatureFunc::PolyFuncType(CustomValidator {
poly_func,
validate,
}) => {
assert_eq!(poly_func, sig.clone().into());
assert!(validate.is_none());
});

let with_custom: SignatureFunc =
CustomValidator::new_with_validator(sig.clone(), NoValidate).into();
let ser: SerSignatureFunc = with_custom.into();
let expected_ser = SerSignatureFunc {
signature: Some(sig.clone().into()),
binary: true,
};
assert_eq!(ser, expected_ser);
let deser = SignatureFunc::try_from(ser.clone()).unwrap();
assert_matches!(&deser,
SignatureFunc::MissingValidateFunc(poly_func) => {
assert_eq!(poly_func, &PolyFuncTypeRV::from(sig.clone()));
}
);

// re-serializing should give the same result
assert_eq!(
SerSignatureFunc::from(SignatureFunc::try_from(ser).unwrap()),
expected_ser
);

let deser_ignored = deser.ignore_missing_validation();
assert_matches!(
&deser_ignored,
&SignatureFunc::PolyFuncType(CustomValidator { validate: None, .. })
);

let custom: SignatureFunc = CustomSig.into();
let ser: SerSignatureFunc = custom.into();
let expected_ser = SerSignatureFunc {
signature: None,
binary: true,
};
assert_eq!(ser, expected_ser);

let deser = SignatureFunc::try_from(ser).unwrap();
assert_matches!(&deser, &SignatureFunc::MissingComputeFunc);

assert_eq!(SerSignatureFunc::from(deser), expected_ser);

let bad_ser = SerSignatureFunc {
signature: None,
binary: false,
};

let err = SignatureFunc::try_from(bad_ser).unwrap_err();

assert_eq!(
err.to_string(),
serde_json::Error::custom("No signature provided and custom computation not expected.")
.to_string()
);
}
}
12 changes: 12 additions & 0 deletions hugr-core/src/hugr/serialize/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,18 @@ fn roundtrip_optype(#[case] optype: impl Into<OpType> + std::fmt::Debug) {
});
}

#[test]
// test all standard extension serialisations are valid against scheme
fn std_extensions_valid() {
let std_reg = crate::std_extensions::std_reg();
for (_, ext) in std_reg.into_iter() {
let val = serde_json::to_value(ext).unwrap();
NamedSchema::check_schemas(&val, get_schemas(true));
// check deserialises correctly, can't check equality because of custom binaries.
let _: crate::extension::Extension = serde_json::from_value(val).unwrap();
}
}

mod proptest {
use super::check_testing_roundtrip;
use super::{NodeSer, SimpleOpDef};
Expand Down
17 changes: 17 additions & 0 deletions hugr-core/src/std_extensions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,24 @@
//!
//! These may be moved to other crates in the future, or dropped altogether.

use crate::extension::ExtensionRegistry;

pub mod arithmetic;
pub mod collections;
pub mod logic;
pub mod ptr;

/// Extension registry with all standard extensions and prelude.
pub fn std_reg() -> ExtensionRegistry {
ExtensionRegistry::try_new([
crate::extension::prelude::PRELUDE.to_owned(),
arithmetic::int_ops::EXTENSION.to_owned(),
arithmetic::int_types::EXTENSION.to_owned(),
arithmetic::conversions::EXTENSION.to_owned(),
arithmetic::float_ops::EXTENSION.to_owned(),
arithmetic::float_types::EXTENSION.to_owned(),
logic::EXTENSION.to_owned(),
ptr::EXTENSION.to_owned(),
])
.unwrap()
}

0 comments on commit 6573761

Please sign in to comment.