Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Extension requires a version #1367

Merged
merged 6 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hugr-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ delegate = { workspace = true }
paste = { workspace = true }
strum = { workspace = true }
strum_macros = { workspace = true }
semver = { version = "1.0.23", features = ["serde"] }

[dev-dependencies]
rstest = { workspace = true }
Expand Down
118 changes: 100 additions & 18 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//! TODO: YAML declaration and parsing. This should be similar to a plugin
//! system (outside the `types` module), which also parses nested [`OpDef`]s.

pub use semver::Version;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pub so downstream crates don't need to add a semver dependency

use std::collections::btree_map;
use std::collections::hash_map;
use std::collections::{BTreeMap, BTreeSet, HashMap};
Expand Down Expand Up @@ -55,21 +56,17 @@ impl ExtensionRegistry {
pub fn try_new(
value: impl IntoIterator<Item = Extension>,
) -> Result<Self, ExtensionRegistryError> {
let mut exts = BTreeMap::new();
let mut res = ExtensionRegistry(BTreeMap::new());

for ext in value.into_iter() {
let prev = exts.insert(ext.name.clone(), ext);
if let Some(prev) = prev {
return Err(ExtensionRegistryError::AlreadyRegistered(
prev.name().clone(),
));
};
res.register(ext)?;
}

// Note this potentially asks extensions to validate themselves against other extensions that
// may *not* be valid themselves yet. It'd be better to order these respecting dependencies,
// or at least to validate the types first - which we don't do at all yet:
// TODO https://github.com/CQCL/hugr/issues/624. However, parametrized types could be
// cyclically dependent, so there is no perfect solution, and this is at least simple.
let res = ExtensionRegistry(exts);
for ext in res.0.values() {
ext.validate(&res)
.map_err(|e| ExtensionRegistryError::InvalidSignature(ext.name().clone(), e))?;
Expand All @@ -82,13 +79,35 @@ impl ExtensionRegistry {
/// Returns a reference to the registered extension if successful.
pub fn register(&mut self, extension: Extension) -> Result<&Extension, ExtensionRegistryError> {
match self.0.entry(extension.name().clone()) {
btree_map::Entry::Occupied(_) => Err(ExtensionRegistryError::AlreadyRegistered(
btree_map::Entry::Occupied(prev) => Err(ExtensionRegistryError::AlreadyRegistered(
extension.name().clone(),
prev.get().version().clone(),
extension.version().clone(),
)),
btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension)),
}
}

/// Registers a new extension to the registry, keeping most up to date if extension exists.
///
/// If extension IDs match, the extension with the higher version is kept.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I note we are not really doing anything with Semver here (it could be any totally-ordered thing), and I wonder whether this should be the place. Replacing an extension with a breaking-change-upgrade could be problematic; should we

  • have a boolean flag here, allow breaking-change upgrades
  • keep multiple major versions of the extension (a single minor version of each)? I think this will be too difficult because then we'd have to version every reference to an extension.
  • Error on any breaking-change upgrade
  • Do something at Hugr validation time. In which case maybe just the checks we have are sufficient, so maybe the right answer is to do nothing, i.e. leave PR as it stands....

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think I prefer to keep the logic here simple and flexible and deal with valid
updates etc. elsewhere.

Maybe I should add a "remove" function so the registry is actually fully flexible, and
this is just a utility for update-only changes?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added remove

/// If versions match, the original extension is kept.
/// Returns a reference to the registered extension if successful.
pub fn register_updated(
&mut self,
extension: Extension,
) -> Result<&Extension, ExtensionRegistryError> {
match self.0.entry(extension.name().clone()) {
btree_map::Entry::Occupied(mut prev) => {
if prev.get().version() < extension.version() {
*prev.get_mut() = extension;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be worth a note that ideally we would check the extensions were identical but this is not possible (no Eq)

Ok(prev.into_mut())
}
btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension)),
}
}

/// Returns the number of extensions in the registry.
pub fn len(&self) -> usize {
self.0.len()
Expand All @@ -103,6 +122,11 @@ impl ExtensionRegistry {
pub fn iter(&self) -> impl Iterator<Item = (&ExtensionId, &Extension)> {
self.0.iter()
}

/// Delete an extension from the registry and return it if it was present.
pub fn remove_extension(&mut self, name: &ExtensionId) -> Option<Extension> {
self.0.remove(name)
}
}

impl IntoIterator for ExtensionRegistry {
Expand Down Expand Up @@ -264,6 +288,8 @@ pub type ExtensionId = IdentList;
/// A extension is a set of capabilities required to execute a graph.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Extension {
/// Extension version, follows semver.
pub version: Version,
/// Unique identifier for the extension.
pub name: ExtensionId,
/// Other extensions defining types used by this extension.
Expand All @@ -286,21 +312,25 @@ pub struct Extension {

impl Extension {
/// Creates a new extension with the given name.
pub fn new(name: ExtensionId) -> Self {
Self::new_with_reqs(name, ExtensionSet::default())
}

/// Creates a new extension with the given name and requirements.
pub fn new_with_reqs(name: ExtensionId, extension_reqs: impl Into<ExtensionSet>) -> Self {
pub fn new(name: ExtensionId, version: Version) -> Self {
Self {
name,
extension_reqs: extension_reqs.into(),
version,
extension_reqs: Default::default(),
types: Default::default(),
values: Default::default(),
operations: Default::default(),
}
}

/// Extend the requirements of this extension with another set of extensions.
pub fn with_reqs(self, extension_reqs: impl Into<ExtensionSet>) -> Self {
Self {
extension_reqs: self.extension_reqs.union(extension_reqs.into()),
..self
}
}

/// Allows read-only access to the operations in this Extension
pub fn get_op(&self, name: &OpNameRef) -> Option<&Arc<op_def::OpDef>> {
self.operations.get(name)
Expand All @@ -321,6 +351,11 @@ impl Extension {
&self.name
}

/// Returns the version of the extension.
pub fn version(&self) -> &Version {
&self.version
}

/// Iterator over the operations of this [`Extension`].
pub fn operations(&self) -> impl Iterator<Item = (&OpName, &Arc<OpDef>)> {
self.operations.iter()
Expand Down Expand Up @@ -382,8 +417,8 @@ impl PartialEq for Extension {
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum ExtensionRegistryError {
/// Extension already defined.
#[error("The registry already contains an extension with id {0}.")]
AlreadyRegistered(ExtensionId),
#[error("The registry already contains an extension with id {0} and version {1}. New extension has version {2}.")]
AlreadyRegistered(ExtensionId, Version, Version),
/// A registered extension has invalid signatures.
#[error("The extension {0} contains an invalid signature, {1}.")]
InvalidSignature(ExtensionId, #[source] SignatureError),
Expand Down Expand Up @@ -544,6 +579,53 @@ pub mod test {
// We re-export this here because mod op_def is private.
pub use super::op_def::test::SimpleOpDef;

use super::*;

impl Extension {
/// Create a new extension for testing, with a 0 version.
pub(crate) fn new_test(name: ExtensionId) -> Self {
Self::new(name, Version::new(0, 0, 0))
}
}

#[test]
fn test_register_update() {
let mut reg = ExtensionRegistry::try_new([]).unwrap();
let ext_1_id = ExtensionId::new("ext1").unwrap();
let ext_2_id = ExtensionId::new("ext2").unwrap();
let ext1 = Extension::new(ext_1_id.clone(), Version::new(1, 0, 0));
let ext1_1 = Extension::new(ext_1_id.clone(), Version::new(1, 1, 0));
let ext1_2 = Extension::new(ext_1_id.clone(), Version::new(0, 2, 0));
let ext2 = Extension::new(ext_2_id, Version::new(1, 0, 0));

reg.register(ext1.clone()).unwrap();
assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 0, 0));

// normal registration fails
assert_eq!(
reg.register(ext1_1.clone()),
Err(ExtensionRegistryError::AlreadyRegistered(
ext_1_id.clone(),
Version::new(1, 0, 0),
Version::new(1, 1, 0)
))
);

// register with update works
reg.register_updated(ext1_1.clone()).unwrap();
assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));

// register with lower version does not change version
reg.register_updated(ext1_2.clone()).unwrap();
assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));

reg.register(ext2.clone()).unwrap();
assert_eq!(reg.get("ext2").unwrap().version(), &Version::new(1, 0, 0));
assert_eq!(reg.len(), 2);

assert!(reg.remove_extension(&ext_1_id).unwrap().version() == &Version::new(1, 1, 0));
assert_eq!(reg.len(), 1);
}
mod proptest {

use ::proptest::{collection::hash_set, prelude::*};
Expand Down
4 changes: 3 additions & 1 deletion hugr-core/src/extension/declarative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ struct ExtensionSetDeclaration {
/// A declarative extension definition.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
struct ExtensionDeclaration {
// TODO add version
/// The name of the extension.
name: ExtensionId,
/// A list of types that this extension provides.
Expand Down Expand Up @@ -150,7 +151,8 @@ impl ExtensionDeclaration {
imports: &ExtensionSet,
ctx: DeclarationContext<'_>,
) -> Result<Extension, ExtensionDeclarationError> {
let mut ext = Extension::new_with_reqs(self.name.clone(), imports.clone());
let mut ext = Extension::new(self.name.clone(), crate::extension::Version::new(0, 0, 0))
.with_reqs(imports.clone());

for t in &self.types {
t.register(&mut ext, ctx)?;
Expand Down
8 changes: 4 additions & 4 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ pub(super) mod test {
#[test]
fn op_def_with_type_scheme() -> Result<(), Box<dyn std::error::Error>> {
let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap();
let mut e = Extension::new(EXT_ID);
let mut e = Extension::new_test(EXT_ID);
const TP: TypeParam = TypeParam::Type { b: TypeBound::Any };
let list_of_var =
Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?);
Expand Down Expand Up @@ -658,7 +658,7 @@ pub(super) mod test {
MAX_NAT
}
}
let mut e = Extension::new(EXT_ID);
let mut e = Extension::new_test(EXT_ID);
let def: &mut crate::extension::OpDef =
e.add_op("MyOp".into(), "".to_string(), SigFun())?;

Expand Down Expand Up @@ -720,7 +720,7 @@ pub(super) mod test {
fn type_scheme_instantiate_var() -> Result<(), Box<dyn std::error::Error>> {
// Check that we can instantiate a PolyFuncTypeRV-scheme with an (external)
// type variable
let mut e = Extension::new(EXT_ID);
let mut e = Extension::new_test(EXT_ID);
let def = e.add_op(
"SimpleOp".into(),
"".into(),
Expand Down Expand Up @@ -755,7 +755,7 @@ pub(super) mod test {
fn instantiate_extension_delta() -> Result<(), Box<dyn std::error::Error>> {
use crate::extension::prelude::{BOOL_T, PRELUDE_REGISTRY};

let mut e = Extension::new(EXT_ID);
let mut e = Extension::new_test(EXT_ID);

let params: Vec<TypeParam> = vec![TypeParam::Extensions];
let db_set = ExtensionSet::type_var(0);
Expand Down
4 changes: 3 additions & 1 deletion hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,11 @@ impl SignatureFromArgs for GenericOpCustom {

/// Name of prelude extension.
pub const PRELUDE_ID: ExtensionId = ExtensionId::new_unchecked("prelude");
/// Extension version.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
lazy_static! {
static ref PRELUDE_DEF: Extension = {
let mut prelude = Extension::new(PRELUDE_ID);
let mut prelude = Extension::new(PRELUDE_ID, VERSION);
prelude
.add_type(
TypeName::new_inline("usize"),
Expand Down
2 changes: 1 addition & 1 deletion hugr-core/src/extension/simple_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ mod test {

lazy_static! {
static ref EXT: Extension = {
let mut e = Extension::new(EXT_ID.clone());
let mut e = Extension::new_test(EXT_ID.clone());
DummyEnum::Dumb.add_to_extension(&mut e).unwrap();
e
};
Expand Down
6 changes: 3 additions & 3 deletions hugr-core/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ const_extension_ids! {
}
#[test]
fn invalid_types() {
let mut e = Extension::new(EXT_ID);
let mut e = Extension::new_test(EXT_ID);
e.add_type(
"MyContainer".into(),
vec![TypeBound::Copyable.into()],
Expand Down Expand Up @@ -570,7 +570,7 @@ fn no_polymorphic_consts() -> Result<(), Box<dyn std::error::Error>> {

pub(crate) fn extension_with_eval_parallel() -> Extension {
let rowp = TypeParam::new_list(TypeBound::Any);
let mut e = Extension::new(EXT_ID);
let mut e = Extension::new_test(EXT_ID);

let inputs = TypeRV::new_row_var_use(0, TypeBound::Any);
let outputs = TypeRV::new_row_var_use(1, TypeBound::Any);
Expand Down Expand Up @@ -671,7 +671,7 @@ fn row_variables() -> Result<(), Box<dyn std::error::Error>> {

#[test]
fn test_polymorphic_call() -> Result<(), Box<dyn std::error::Error>> {
let mut e = Extension::new(EXT_ID);
let mut e = Extension::new_test(EXT_ID);

let params: Vec<TypeParam> = vec![
TypeBound::Any.into(),
Expand Down
5 changes: 4 additions & 1 deletion hugr-core/src/std_extensions/arithmetic/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ use lazy_static::lazy_static;
mod const_fold;
/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions");
/// Extension version.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);

/// Extension for conversions between floats and integers.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
Expand Down Expand Up @@ -121,8 +123,9 @@ impl MakeExtensionOp for ConvertOpType {
lazy_static! {
/// Extension for conversions between integers and floats.
pub static ref EXTENSION: Extension = {
let mut extension = Extension::new_with_reqs(
let mut extension = Extension::new(
EXTENSION_ID,
VERSION).with_reqs(
ExtensionSet::from_iter(vec![
super::int_types::EXTENSION_ID,
super::float_types::EXTENSION_ID,
Expand Down
5 changes: 4 additions & 1 deletion hugr-core/src/std_extensions/arithmetic/float_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ use lazy_static::lazy_static;
mod const_fold;
/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float");
/// Extension version.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);

/// Integer extension operation definitions.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
Expand Down Expand Up @@ -99,8 +101,9 @@ impl MakeOpDef for FloatOps {
lazy_static! {
/// Extension for basic float operations.
pub static ref EXTENSION: Extension = {
let mut extension = Extension::new_with_reqs(
let mut extension = Extension::new(
EXTENSION_ID,
VERSION).with_reqs(
ExtensionSet::singleton(&super::int_types::EXTENSION_ID),
);

Expand Down
4 changes: 3 additions & 1 deletion hugr-core/src/std_extensions/arithmetic/float_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use lazy_static::lazy_static;

/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float.types");
/// Extension version.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);

/// Identifier for the 64-bit IEEE 754-2019 floating-point type.
const FLOAT_TYPE_ID: TypeName = TypeName::new_inline("float64");
Expand Down Expand Up @@ -76,7 +78,7 @@ impl CustomConst for ConstF64 {
lazy_static! {
/// Extension defining the float type.
pub static ref EXTENSION: Extension = {
let mut extension = Extension::new(EXTENSION_ID);
let mut extension = Extension::new(EXTENSION_ID, VERSION);

extension
.add_type(
Expand Down
7 changes: 5 additions & 2 deletions hugr-core/src/std_extensions/arithmetic/int_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ mod const_fold;

/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.int");
/// Extension version.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);

struct IOValidator {
// whether the first type argument should be greater than or equal to the second
Expand Down Expand Up @@ -261,9 +263,10 @@ fn iunop_sig() -> PolyFuncTypeRV {
lazy_static! {
/// Extension for basic integer operations.
pub static ref EXTENSION: Extension = {
let mut extension = Extension::new_with_reqs(
let mut extension = Extension::new(
EXTENSION_ID,
ExtensionSet::singleton(&super::int_types::EXTENSION_ID),
VERSION).with_reqs(
ExtensionSet::singleton(&super::int_types::EXTENSION_ID)
);

IntOpDef::load_all_ops(&mut extension).unwrap();
Expand Down
Loading
Loading