From 56e629eb5ca79433265b0896599556989b11708e Mon Sep 17 00:00:00 2001 From: Jakob Kraus Date: Thu, 17 Oct 2024 16:55:29 +0200 Subject: [PATCH] refactor: replace outgoing_edges and incoming_edges with edges --- .../src/medrecord/querying/nodes/operand.rs | 23 ++----- .../src/medrecord/querying/nodes/operation.rs | 68 ++++++++----------- medmodels/_medmodels.pyi | 3 +- medmodels/medrecord/querying.py | 9 ++- rustmodels/src/medrecord/querying/nodes.rs | 8 +-- 5 files changed, 40 insertions(+), 71 deletions(-) diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs b/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs index 329dc8f..2d730f8 100644 --- a/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs +++ b/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs @@ -110,21 +110,12 @@ impl NodeOperand { }); } - pub fn outgoing_edges(&mut self) -> Wrapper { + pub fn edges(&mut self, direction: EdgeDirection) -> Wrapper { let operand = Wrapper::::new(); - self.operations.push(NodeOperation::OutgoingEdges { - operand: operand.clone(), - }); - - operand - } - - pub fn incoming_edges(&mut self) -> Wrapper { - let operand = Wrapper::::new(); - - self.operations.push(NodeOperation::IncomingEdges { + self.operations.push(NodeOperation::Edges { operand: operand.clone(), + direction, }); operand @@ -211,12 +202,8 @@ impl Wrapper { self.0.write_or_panic().has_attribute(attribute); } - pub fn outgoing_edges(&mut self) -> Wrapper { - self.0.write_or_panic().outgoing_edges() - } - - pub fn incoming_edges(&mut self) -> Wrapper { - self.0.write_or_panic().incoming_edges() + pub fn edges(&mut self, direction: EdgeDirection) -> Wrapper { + self.0.write_or_panic().edges(direction) } pub fn neighbors(&mut self, direction: EdgeDirection) -> Wrapper { diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs b/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs index 7db9544..8e9d040 100644 --- a/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs +++ b/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs @@ -58,11 +58,9 @@ pub enum NodeOperation { attribute: CardinalityWrapper, }, - OutgoingEdges { - operand: Wrapper, - }, - IncomingEdges { + Edges { operand: Wrapper, + direction: EdgeDirection, }, Neighbors { @@ -97,11 +95,9 @@ impl DeepClone for NodeOperation { Self::HasAttribute { attribute } => Self::HasAttribute { attribute: attribute.clone(), }, - Self::OutgoingEdges { operand } => Self::OutgoingEdges { - operand: operand.deep_clone(), - }, - Self::IncomingEdges { operand } => Self::IncomingEdges { + Self::Edges { operand, direction } => Self::Edges { operand: operand.deep_clone(), + direction: direction.clone(), }, Self::Neighbors { operand, @@ -153,15 +149,11 @@ impl NodeOperation { node_indices, attribute.clone(), )), - Self::OutgoingEdges { operand } => Box::new(Self::evaluate_outgoing_edges( - medrecord, - node_indices, - operand.clone(), - )?), - Self::IncomingEdges { operand } => Box::new(Self::evaluate_incoming_edges( + Self::Edges { operand, direction } => Box::new(Self::evaluate_edges( medrecord, node_indices, operand.clone(), + direction.clone(), )?), Self::Neighbors { operand, @@ -315,40 +307,36 @@ impl NodeOperation { } #[inline] - fn evaluate_outgoing_edges<'a>( - medrecord: &'a MedRecord, - node_indices: impl Iterator, - operand: Wrapper, - ) -> MedRecordResult> { - let edge_indices = operand.evaluate(medrecord)?.collect::(); - - Ok(node_indices.filter(move |node_index| { - let outgoing_edge_indices = medrecord - .outgoing_edges(node_index) - .expect("Node must exist"); - - let outgoing_edge_indices = outgoing_edge_indices.collect::(); - - !outgoing_edge_indices.is_disjoint(&edge_indices) - })) - } - - #[inline] - fn evaluate_incoming_edges<'a>( + fn evaluate_edges<'a>( medrecord: &'a MedRecord, node_indices: impl Iterator, operand: Wrapper, + direction: EdgeDirection, ) -> MedRecordResult> { let edge_indices = operand.evaluate(medrecord)?.collect::(); Ok(node_indices.filter(move |node_index| { - let incoming_edge_indices = medrecord - .incoming_edges(node_index) - .expect("Node must exist"); - - let incoming_edge_indices = incoming_edge_indices.collect::(); + let connected_indices = match direction { + EdgeDirection::Incoming => medrecord + .outgoing_edges(node_index) + .expect("Node must exist") + .collect::(), + EdgeDirection::Outgoing => medrecord + .incoming_edges(node_index) + .expect("Node must exist") + .collect::(), + EdgeDirection::Both => medrecord + .incoming_edges(node_index) + .expect("Node must exist") + .chain( + medrecord + .outgoing_edges(node_index) + .expect("Node must exist"), + ) + .collect::(), + }; - !incoming_edge_indices.is_disjoint(&edge_indices) + !connected_indices.is_disjoint(&edge_indices) })) } diff --git a/medmodels/_medmodels.pyi b/medmodels/_medmodels.pyi index 55f090b..fe36589 100644 --- a/medmodels/_medmodels.pyi +++ b/medmodels/_medmodels.pyi @@ -248,8 +248,7 @@ class PyNodeOperand: def has_attribute( self, attribute: Union[MedRecordAttribute, List[MedRecordAttribute]] ) -> None: ... - def outgoing_edges(self) -> PyEdgeOperand: ... - def incoming_edges(self) -> PyEdgeOperand: ... + def edges(self, direction: PyEdgeDirection) -> PyEdgeOperand: ... def neighbors(self, direction: PyEdgeDirection) -> PyNodeOperand: ... def either_or( self, diff --git a/medmodels/medrecord/querying.py b/medmodels/medrecord/querying.py index f274dd5..75c05bf 100644 --- a/medmodels/medrecord/querying.py +++ b/medmodels/medrecord/querying.py @@ -173,11 +173,10 @@ def has_attribute( ) -> None: self._node_operand.has_attribute(attribute) - def outgoing_edges(self) -> EdgeOperand: - return EdgeOperand._from_py_edge_operand(self._node_operand.outgoing_edges()) - - def incoming_edges(self) -> EdgeOperand: - return EdgeOperand._from_py_edge_operand(self._node_operand.incoming_edges()) + def edges(self, direction: EdgeDirection = EdgeDirection.BOTH) -> EdgeOperand: + return EdgeOperand._from_py_edge_operand( + self._node_operand.edges(direction._into_py_edge_direction()) + ) def neighbors( self, edge_direction: EdgeDirection = EdgeDirection.OUTGOING diff --git a/rustmodels/src/medrecord/querying/nodes.rs b/rustmodels/src/medrecord/querying/nodes.rs index ec4120f..5370eea 100644 --- a/rustmodels/src/medrecord/querying/nodes.rs +++ b/rustmodels/src/medrecord/querying/nodes.rs @@ -82,12 +82,8 @@ impl PyNodeOperand { self.0.has_attribute(attribute); } - pub fn outgoing_edges(&mut self) -> PyEdgeOperand { - self.0.outgoing_edges().into() - } - - pub fn incoming_edges(&mut self) -> PyEdgeOperand { - self.0.incoming_edges().into() + pub fn edges(&mut self, direction: PyEdgeDirection) -> PyEdgeOperand { + self.0.edges(direction.into()).into() } pub fn neighbors(&mut self, direction: PyEdgeDirection) -> PyNodeOperand {