Skip to content

Commit

Permalink
allow updating extensions with higher version in registry.
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Jul 25, 2024
1 parent d60b0c7 commit 6c87e96
Showing 1 changed file with 69 additions and 3 deletions.
72 changes: 69 additions & 3 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,13 @@ impl ExtensionRegistry {
) -> Result<Self, ExtensionRegistryError> {
let mut exts = BTreeMap::new();
for ext in value.into_iter() {
let ext_v = ext.version().clone();
let prev = exts.insert(ext.name.clone(), ext);
if let Some(prev) = prev {
return Err(ExtensionRegistryError::AlreadyRegistered(
prev.name().clone(),
prev.version().clone(),
ext_v,
));
};
}
Expand All @@ -83,13 +86,35 @@ impl ExtensionRegistry {
/// Returns a reference to the registered extension if successful.
pub fn register(&mut self, extension: Extension) -> Result<&Extension, ExtensionRegistryError> {
match self.0.entry(extension.name().clone()) {
btree_map::Entry::Occupied(_) => Err(ExtensionRegistryError::AlreadyRegistered(
btree_map::Entry::Occupied(prev) => Err(ExtensionRegistryError::AlreadyRegistered(
extension.name().clone(),
prev.get().version().clone(),
extension.version().clone(),
)),
btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension)),
}
}

/// Registers a new extension to the registry, keeping most up to date if extension exists.
///
/// If extension IDs match, the extension with the higher version is kept.
/// If versions match, the original extension is kept.
/// Returns a reference to the registered extension if successful.
pub fn register_updated(
&mut self,
extension: Extension,
) -> Result<&Extension, ExtensionRegistryError> {
match self.0.entry(extension.name().clone()) {
btree_map::Entry::Occupied(mut prev) => {
if prev.get().version() < extension.version() {
*prev.get_mut() = extension;
}
Ok(prev.into_mut())
}
btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension)),
}
}

/// Returns the number of extensions in the registry.
pub fn len(&self) -> usize {
self.0.len()
Expand Down Expand Up @@ -328,6 +353,11 @@ impl Extension {
&self.name
}

/// Returns the version of the extension.
pub fn version(&self) -> &Version {
&self.version
}

/// Iterator over the operations of this [`Extension`].
pub fn operations(&self) -> impl Iterator<Item = (&OpName, &Arc<OpDef>)> {
self.operations.iter()
Expand Down Expand Up @@ -389,8 +419,8 @@ impl PartialEq for Extension {
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum ExtensionRegistryError {
/// Extension already defined.
#[error("The registry already contains an extension with id {0}.")]
AlreadyRegistered(ExtensionId),
#[error("The registry already contains an extension with id {0} and version {1}. New extension has version {2}.")]
AlreadyRegistered(ExtensionId, Version, Version),
/// A registered extension has invalid signatures.
#[error("The extension {0} contains an invalid signature, {1}.")]
InvalidSignature(ExtensionId, #[source] SignatureError),
Expand Down Expand Up @@ -551,6 +581,42 @@ pub mod test {
// We re-export this here because mod op_def is private.
pub use super::op_def::test::SimpleOpDef;

use super::*;
#[test]
fn test_register_update() {
let mut reg = ExtensionRegistry::try_new([]).unwrap();
let ext_1_id = ExtensionId::new("ext1").unwrap();
let ext_2_id = ExtensionId::new("ext2").unwrap();
let ext1 = Extension::new(ext_1_id.clone(), Version::new(1, 0, 0));
let ext1_1 = Extension::new(ext_1_id.clone(), Version::new(1, 1, 0));
let ext1_2 = Extension::new(ext_1_id.clone(), Version::new(0, 2, 0));
let ext2 = Extension::new(ext_2_id, Version::new(1, 0, 0));

reg.register(ext1.clone()).unwrap();
assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 0, 0));

// normal registration fails
assert_eq!(
reg.register(ext1_1.clone()),
Err(ExtensionRegistryError::AlreadyRegistered(
ext_1_id,
Version::new(1, 0, 0),
Version::new(1, 1, 0)
))
);

// register with update works
reg.register_updated(ext1_1.clone()).unwrap();
assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));

// register with lower version does not change version
reg.register_updated(ext1_2.clone()).unwrap();
assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));

reg.register(ext2.clone()).unwrap();
assert_eq!(reg.get("ext2").unwrap().version(), &Version::new(1, 0, 0));
assert_eq!(reg.len(), 2);
}
mod proptest {

use ::proptest::{collection::hash_set, prelude::*};
Expand Down

0 comments on commit 6c87e96

Please sign in to comment.