Skip to content

Commit

Permalink
refactor: replace outgoing_edges and incoming_edges with edges
Browse files Browse the repository at this point in the history
  • Loading branch information
JabobKrauskopf committed Oct 17, 2024
1 parent a183c40 commit 56e629e
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 71 deletions.
23 changes: 5 additions & 18 deletions crates/medmodels-core/src/medrecord/querying/nodes/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,21 +110,12 @@ impl NodeOperand {
});
}

pub fn outgoing_edges(&mut self) -> Wrapper<EdgeOperand> {
pub fn edges(&mut self, direction: EdgeDirection) -> Wrapper<EdgeOperand> {
let operand = Wrapper::<EdgeOperand>::new();

self.operations.push(NodeOperation::OutgoingEdges {
operand: operand.clone(),
});

operand
}

pub fn incoming_edges(&mut self) -> Wrapper<EdgeOperand> {
let operand = Wrapper::<EdgeOperand>::new();

self.operations.push(NodeOperation::IncomingEdges {
self.operations.push(NodeOperation::Edges {
operand: operand.clone(),
direction,
});

operand
Expand Down Expand Up @@ -211,12 +202,8 @@ impl Wrapper<NodeOperand> {
self.0.write_or_panic().has_attribute(attribute);
}

pub fn outgoing_edges(&mut self) -> Wrapper<EdgeOperand> {
self.0.write_or_panic().outgoing_edges()
}

pub fn incoming_edges(&mut self) -> Wrapper<EdgeOperand> {
self.0.write_or_panic().incoming_edges()
pub fn edges(&mut self, direction: EdgeDirection) -> Wrapper<EdgeOperand> {
self.0.write_or_panic().edges(direction)
}

pub fn neighbors(&mut self, direction: EdgeDirection) -> Wrapper<NodeOperand> {
Expand Down
68 changes: 28 additions & 40 deletions crates/medmodels-core/src/medrecord/querying/nodes/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,9 @@ pub enum NodeOperation {
attribute: CardinalityWrapper<MedRecordAttribute>,
},

OutgoingEdges {
operand: Wrapper<EdgeOperand>,
},
IncomingEdges {
Edges {
operand: Wrapper<EdgeOperand>,
direction: EdgeDirection,
},

Neighbors {
Expand Down Expand Up @@ -97,11 +95,9 @@ impl DeepClone for NodeOperation {
Self::HasAttribute { attribute } => Self::HasAttribute {
attribute: attribute.clone(),
},
Self::OutgoingEdges { operand } => Self::OutgoingEdges {
operand: operand.deep_clone(),
},
Self::IncomingEdges { operand } => Self::IncomingEdges {
Self::Edges { operand, direction } => Self::Edges {
operand: operand.deep_clone(),
direction: direction.clone(),
},
Self::Neighbors {
operand,
Expand Down Expand Up @@ -153,15 +149,11 @@ impl NodeOperation {
node_indices,
attribute.clone(),
)),
Self::OutgoingEdges { operand } => Box::new(Self::evaluate_outgoing_edges(
medrecord,
node_indices,
operand.clone(),
)?),
Self::IncomingEdges { operand } => Box::new(Self::evaluate_incoming_edges(
Self::Edges { operand, direction } => Box::new(Self::evaluate_edges(
medrecord,
node_indices,
operand.clone(),
direction.clone(),
)?),
Self::Neighbors {
operand,
Expand Down Expand Up @@ -315,40 +307,36 @@ impl NodeOperation {
}

#[inline]
fn evaluate_outgoing_edges<'a>(
medrecord: &'a MedRecord,
node_indices: impl Iterator<Item = &'a NodeIndex>,
operand: Wrapper<EdgeOperand>,
) -> MedRecordResult<impl Iterator<Item = &'a NodeIndex>> {
let edge_indices = operand.evaluate(medrecord)?.collect::<RoaringBitmap>();

Ok(node_indices.filter(move |node_index| {
let outgoing_edge_indices = medrecord
.outgoing_edges(node_index)
.expect("Node must exist");

let outgoing_edge_indices = outgoing_edge_indices.collect::<RoaringBitmap>();

!outgoing_edge_indices.is_disjoint(&edge_indices)
}))
}

#[inline]
fn evaluate_incoming_edges<'a>(
fn evaluate_edges<'a>(
medrecord: &'a MedRecord,
node_indices: impl Iterator<Item = &'a NodeIndex>,
operand: Wrapper<EdgeOperand>,
direction: EdgeDirection,
) -> MedRecordResult<impl Iterator<Item = &'a NodeIndex>> {
let edge_indices = operand.evaluate(medrecord)?.collect::<RoaringBitmap>();

Ok(node_indices.filter(move |node_index| {
let incoming_edge_indices = medrecord
.incoming_edges(node_index)
.expect("Node must exist");

let incoming_edge_indices = incoming_edge_indices.collect::<RoaringBitmap>();
let connected_indices = match direction {
EdgeDirection::Incoming => medrecord
.outgoing_edges(node_index)
.expect("Node must exist")
.collect::<RoaringBitmap>(),
EdgeDirection::Outgoing => medrecord
.incoming_edges(node_index)
.expect("Node must exist")
.collect::<RoaringBitmap>(),
EdgeDirection::Both => medrecord
.incoming_edges(node_index)
.expect("Node must exist")
.chain(
medrecord
.outgoing_edges(node_index)
.expect("Node must exist"),
)
.collect::<RoaringBitmap>(),
};

!incoming_edge_indices.is_disjoint(&edge_indices)
!connected_indices.is_disjoint(&edge_indices)
}))
}

Expand Down
3 changes: 1 addition & 2 deletions medmodels/_medmodels.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,7 @@ class PyNodeOperand:
def has_attribute(
self, attribute: Union[MedRecordAttribute, List[MedRecordAttribute]]
) -> None: ...
def outgoing_edges(self) -> PyEdgeOperand: ...
def incoming_edges(self) -> PyEdgeOperand: ...
def edges(self, direction: PyEdgeDirection) -> PyEdgeOperand: ...
def neighbors(self, direction: PyEdgeDirection) -> PyNodeOperand: ...
def either_or(
self,
Expand Down
9 changes: 4 additions & 5 deletions medmodels/medrecord/querying.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,10 @@ def has_attribute(
) -> None:
self._node_operand.has_attribute(attribute)

def outgoing_edges(self) -> EdgeOperand:
return EdgeOperand._from_py_edge_operand(self._node_operand.outgoing_edges())

def incoming_edges(self) -> EdgeOperand:
return EdgeOperand._from_py_edge_operand(self._node_operand.incoming_edges())
def edges(self, direction: EdgeDirection = EdgeDirection.BOTH) -> EdgeOperand:
return EdgeOperand._from_py_edge_operand(
self._node_operand.edges(direction._into_py_edge_direction())
)

def neighbors(
self, edge_direction: EdgeDirection = EdgeDirection.OUTGOING
Expand Down
8 changes: 2 additions & 6 deletions rustmodels/src/medrecord/querying/nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,8 @@ impl PyNodeOperand {
self.0.has_attribute(attribute);
}

pub fn outgoing_edges(&mut self) -> PyEdgeOperand {
self.0.outgoing_edges().into()
}

pub fn incoming_edges(&mut self) -> PyEdgeOperand {
self.0.incoming_edges().into()
pub fn edges(&mut self, direction: PyEdgeDirection) -> PyEdgeOperand {
self.0.edges(direction.into()).into()
}

pub fn neighbors(&mut self, direction: PyEdgeDirection) -> PyNodeOperand {
Expand Down

0 comments on commit 56e629e

Please sign in to comment.