diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 058d48aa4..c5da85276 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -58,10 +58,13 @@ impl ExtensionRegistry { ) -> Result { let mut exts = BTreeMap::new(); for ext in value.into_iter() { + let ext_v = ext.version().clone(); let prev = exts.insert(ext.name.clone(), ext); if let Some(prev) = prev { return Err(ExtensionRegistryError::AlreadyRegistered( prev.name().clone(), + prev.version().clone(), + ext_v, )); }; } @@ -83,13 +86,35 @@ impl ExtensionRegistry { /// Returns a reference to the registered extension if successful. pub fn register(&mut self, extension: Extension) -> Result<&Extension, ExtensionRegistryError> { match self.0.entry(extension.name().clone()) { - btree_map::Entry::Occupied(_) => Err(ExtensionRegistryError::AlreadyRegistered( + btree_map::Entry::Occupied(prev) => Err(ExtensionRegistryError::AlreadyRegistered( extension.name().clone(), + prev.get().version().clone(), + extension.version().clone(), )), btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension)), } } + /// Registers a new extension to the registry, keeping most up to date if extension exists. + /// + /// If extension IDs match, the extension with the higher version is kept. + /// If versions match, the original extension is kept. + /// Returns a reference to the registered extension if successful. + pub fn register_updated( + &mut self, + extension: Extension, + ) -> Result<&Extension, ExtensionRegistryError> { + match self.0.entry(extension.name().clone()) { + btree_map::Entry::Occupied(mut prev) => { + if prev.get().version() < extension.version() { + *prev.get_mut() = extension; + } + Ok(prev.into_mut()) + } + btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension)), + } + } + /// Returns the number of extensions in the registry. pub fn len(&self) -> usize { self.0.len() @@ -328,6 +353,11 @@ impl Extension { &self.name } + /// Returns the version of the extension. + pub fn version(&self) -> &Version { + &self.version + } + /// Iterator over the operations of this [`Extension`]. pub fn operations(&self) -> impl Iterator)> { self.operations.iter() @@ -389,8 +419,8 @@ impl PartialEq for Extension { #[derive(Debug, Clone, Error, PartialEq, Eq)] pub enum ExtensionRegistryError { /// Extension already defined. - #[error("The registry already contains an extension with id {0}.")] - AlreadyRegistered(ExtensionId), + #[error("The registry already contains an extension with id {0} and version {1}. New extension has version {2}.")] + AlreadyRegistered(ExtensionId, Version, Version), /// A registered extension has invalid signatures. #[error("The extension {0} contains an invalid signature, {1}.")] InvalidSignature(ExtensionId, #[source] SignatureError), @@ -551,6 +581,42 @@ pub mod test { // We re-export this here because mod op_def is private. pub use super::op_def::test::SimpleOpDef; + use super::*; + #[test] + fn test_register_update() { + let mut reg = ExtensionRegistry::try_new([]).unwrap(); + let ext_1_id = ExtensionId::new("ext1").unwrap(); + let ext_2_id = ExtensionId::new("ext2").unwrap(); + let ext1 = Extension::new(ext_1_id.clone(), Version::new(1, 0, 0)); + let ext1_1 = Extension::new(ext_1_id.clone(), Version::new(1, 1, 0)); + let ext1_2 = Extension::new(ext_1_id.clone(), Version::new(0, 2, 0)); + let ext2 = Extension::new(ext_2_id, Version::new(1, 0, 0)); + + reg.register(ext1.clone()).unwrap(); + assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 0, 0)); + + // normal registration fails + assert_eq!( + reg.register(ext1_1.clone()), + Err(ExtensionRegistryError::AlreadyRegistered( + ext_1_id, + Version::new(1, 0, 0), + Version::new(1, 1, 0) + )) + ); + + // register with update works + reg.register_updated(ext1_1.clone()).unwrap(); + assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0)); + + // register with lower version does not change version + reg.register_updated(ext1_2.clone()).unwrap(); + assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0)); + + reg.register(ext2.clone()).unwrap(); + assert_eq!(reg.get("ext2").unwrap().version(), &Version::new(1, 0, 0)); + assert_eq!(reg.len(), 2); + } mod proptest { use ::proptest::{collection::hash_set, prelude::*};