Skip to content

Commit

Permalink
refactor: rust query engine
Browse files Browse the repository at this point in the history
  • Loading branch information
JabobKrauskopf committed Sep 17, 2024
1 parent cd391ec commit 5df1278
Show file tree
Hide file tree
Showing 31 changed files with 1,965 additions and 3,558 deletions.
21 changes: 19 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 4 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,9 @@ description = "Limebit MedModels Crate"
[workspace.dependencies]
hashbrown = { version = "0.14.5", features = ["serde"] }
serde = { version = "1.0.203", features = ["derive"] }
ron = "0.8.1"
chrono = { version = "0.4.38", features = ["serde"] }
pyo3 = { version = "0.21.2", features = ["chrono"] }
polars = { version = "0.40.0", features = ["polars-io"] }
pyo3-polars = "0.14.0"
chrono = { version = "0.4.38", features = ["serde"] }

medmodels = { version = "0.1.2", path = "crates/medmodels" }
medmodels-core = { version = "0.1.2", path = "crates/medmodels-core" }
medmodels-utils = { version = "0.1.2", path = "crates/medmodels-utils" }
medmodels = { version = "0.1.1", path = "crates/medmodels" }
medmodels-core = { version = "0.1.1", path = "crates/medmodels-core" }
medmodels-utils = { version = "0.1.1", path = "crates/medmodels-utils" }
3 changes: 2 additions & 1 deletion crates/medmodels-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ medmodels-utils = { workspace = true }

polars = { workspace = true }
serde = { workspace = true }
ron = { workspace = true }
chrono = { workspace = true }
ron = "0.8.1"
roaring = "0.10.6"
3 changes: 3 additions & 0 deletions crates/medmodels-core/src/medrecord/datatypes/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#![allow(dead_code)]
// TODO: Remove the `#![allow(dead_code)]` attribute above once the query engine is implemented

mod attribute;
mod value;

Expand Down
11 changes: 6 additions & 5 deletions crates/medmodels-core/src/medrecord/example_dataset/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,16 @@ impl MedRecord {
.into_reader_with_file_handle(cursor)
.finish()
.expect("DataFrame can be built");
let patient_diagnosis_ids = (0..patient_diagnosis.height()).collect::<Vec<_>>();
let patient_diagnosis_ids = (0..patient_diagnosis.height() as u32).collect::<Vec<_>>();

let cursor = Cursor::new(PATIENT_DRUG);
let patient_drug = CsvReadOptions::default()
.with_has_header(true)
.into_reader_with_file_handle(cursor)
.finish()
.expect("DataFrame can be built");
let patient_drug_ids = (patient_diagnosis.height()
..patient_diagnosis.height() + patient_drug.height())
let patient_drug_ids = (patient_diagnosis.height() as u32
..(patient_diagnosis.height() + patient_drug.height()) as u32)
.collect::<Vec<_>>();

let cursor = Cursor::new(PATIENT_PROCEDURE);
Expand All @@ -89,8 +89,9 @@ impl MedRecord {
.into_reader_with_file_handle(cursor)
.finish()
.expect("DataFrame can be built");
let patient_procedure_ids = (patient_diagnosis.height() + patient_drug.height()
..patient_diagnosis.height() + patient_drug.height() + patient_procedure.height())
let patient_procedure_ids = ((patient_diagnosis.height() + patient_drug.height()) as u32
..(patient_diagnosis.height() + patient_drug.height() + patient_procedure.height())
as u32)
.collect::<Vec<_>>();

let mut medrecord = Self::from_dataframes(
Expand Down
6 changes: 3 additions & 3 deletions crates/medmodels-core/src/medrecord/graph/edge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edge {
pub attributes: Attributes,
pub(super) source_node_index: NodeIndex,
pub(super) target_node_index: NodeIndex,
pub(crate) attributes: Attributes,
pub(crate) source_node_index: NodeIndex,
pub(crate) target_node_index: NodeIndex,
}

impl Edge {
Expand Down
14 changes: 7 additions & 7 deletions crates/medmodels-core/src/medrecord/graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@ use node::Node;
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, HashSet},
sync::atomic::AtomicUsize,
sync::atomic::AtomicU32,
};

pub type NodeIndex = MedRecordAttribute;
pub type EdgeIndex = usize;
pub type EdgeIndex = u32;
pub type Attributes = HashMap<MedRecordAttribute, MedRecordValue>;

#[derive(Serialize, Deserialize, Debug)]
pub(super) struct Graph {
pub(crate) nodes: MrHashMap<NodeIndex, Node>,
pub(crate) edges: MrHashMap<EdgeIndex, Edge>,
edge_index_counter: AtomicUsize,
edge_index_counter: AtomicU32,
}

#[allow(dead_code)]
Expand All @@ -29,29 +29,29 @@ impl Graph {
Self {
nodes: MrHashMap::new(),
edges: MrHashMap::new(),
edge_index_counter: AtomicUsize::new(0),
edge_index_counter: AtomicU32::new(0),
}
}

pub fn with_capacity(node_capacity: usize, edge_capacity: usize) -> Self {
Self {
nodes: MrHashMap::with_capacity(node_capacity),
edges: MrHashMap::with_capacity(edge_capacity),
edge_index_counter: AtomicUsize::new(0),
edge_index_counter: AtomicU32::new(0),
}
}

pub fn clear(&mut self) {
self.nodes.clear();
self.edges.clear();

self.edge_index_counter = AtomicUsize::new(0);
self.edge_index_counter = AtomicU32::new(0);
}

pub fn clear_edges(&mut self) {
self.edges.clear();

self.edge_index_counter = AtomicUsize::new(0);
self.edge_index_counter = AtomicU32::new(0);
}

pub fn node_count(&self) -> usize {
Expand Down
6 changes: 3 additions & 3 deletions crates/medmodels-core/src/medrecord/graph/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node {
pub attributes: Attributes,
pub(super) outgoing_edge_indices: MrHashSet<EdgeIndex>,
pub(super) incoming_edge_indices: MrHashSet<EdgeIndex>,
pub(crate) attributes: Attributes,
pub(crate) outgoing_edge_indices: MrHashSet<EdgeIndex>,
pub(crate) incoming_edge_indices: MrHashSet<EdgeIndex>,
}

impl Node {
Expand Down
78 changes: 70 additions & 8 deletions crates/medmodels-core/src/medrecord/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ pub use self::{
graph::{Attributes, EdgeIndex, NodeIndex},
group_mapping::Group,
querying::{
edge, node, ArithmeticOperation, EdgeAttributeOperand, EdgeIndexOperand, EdgeOperand,
EdgeOperation, NodeAttributeOperand, NodeIndexOperand, NodeOperand, NodeOperation,
TransformationOperation, ValueOperand,
edges::{EdgeOperand, EdgeValueOperand, EdgeValuesOperand},
nodes::{NodeOperand, NodeValueOperand, NodeValuesOperand},
values::{ComparisonOperand, ValueOperand, ValuesOperand},
wrapper::{CardinalityWrapper, Wrapper},
},
schema::{AttributeDataType, AttributeType, GroupSchema, Schema},
};
Expand All @@ -22,7 +23,7 @@ use ::polars::frame::DataFrame;
use graph::Graph;
use group_mapping::GroupMapping;
use polars::{dataframe_to_edges, dataframe_to_nodes};
use querying::{EdgeSelection, NodeSelection};
use querying::{edges::EdgeSelection, nodes::NodeSelection};
use serde::{Deserialize, Serialize};
use std::{fs, mem, path::Path};

Expand Down Expand Up @@ -706,12 +707,18 @@ impl MedRecord {
self.group_mapping.clear();
}

pub fn select_nodes(&self, operation: NodeOperation) -> NodeSelection {
NodeSelection::new(self, operation)
pub fn select_nodes<Q>(&self, query: Q) -> NodeSelection
where
Q: FnOnce(&mut Wrapper<NodeOperand>),
{
NodeSelection::new(self, query)
}

pub fn select_edges(&self, operation: EdgeOperation) -> EdgeSelection {
EdgeSelection::new(self, operation)
pub fn select_edges<Q>(&self, query: Q) -> EdgeSelection
where
Q: FnOnce(&mut Wrapper<EdgeOperand>),
{
EdgeSelection::new(self, query)
}
}

Expand Down Expand Up @@ -1889,4 +1896,59 @@ mod test {
assert_eq!(0, medrecord.edge_count());
assert_eq!(0, medrecord.group_count());
}

// Exploratory test for the closure-based `select_nodes` query API.
//
// Builds a small record (four nodes, five edges carrying a "time" attribute),
// registers two groups, runs a node query over incoming-edge "time" values and
// prints the selected nodes.
//
// NOTE(review): this test contains no assertion — it only prints the query
// result, so it cannot fail on a wrong selection. Consider asserting on the
// collected nodes and giving the test a descriptive name.
#[test]
fn test_test() {
// Four nodes, "0" through "3", each with an empty attribute map.
let nodes = vec![
("0".into(), HashMap::new()),
("1".into(), HashMap::new()),
("2".into(), HashMap::new()),
("3".into(), HashMap::new()),
];

// Edges as (source, target, attributes): four parallel edges "0" -> "1"
// with time 0, 2, 3 and 4, plus a single edge "0" -> "2" with time 5.
let edges = vec![
(
"0".into(),
"1".into(),
HashMap::from([("time".into(), 0.into())]),
),
(
"0".into(),
"1".into(),
HashMap::from([("time".into(), 2.into())]),
),
(
"0".into(),
"1".into(),
HashMap::from([("time".into(), 3.into())]),
),
(
"0".into(),
"1".into(),
HashMap::from([("time".into(), 4.into())]),
),
(
"0".into(),
"2".into(),
HashMap::from([("time".into(), 5.into())]),
),
];

// Construction is expected to succeed for this fixture; `unwrap` panics otherwise.
let mut medrecord = MedRecord::from_tuples(nodes, Some(edges), None).unwrap();

// Register two groups: "treatment" with "1" and "outcome" with "2"
// (presumably node indices, with the trailing `None` meaning no edges — TODO confirm).
// NOTE(review): the query below never references these groups.
medrecord
.add_group("treatment".into(), Some(vec!["1".into()]), None)
.unwrap();
medrecord
.add_group("outcome".into(), Some(vec!["2".into()]), None)
.unwrap();

// Query: look at each node's incoming edges, take their "time" attribute and
// require that any of those values is less than 6.
// NOTE(review): with the fixture above this presumably selects nodes "1" and
// "2" (the only nodes with incoming edges) — confirm against the query engine.
let nodes = medrecord.select_nodes(|node| {
let edges = node.incoming_edges();

edges.attribute("time").any(|value| value.less_than(6))
});

// Debug output only; visible when running tests with `--nocapture`.
println!("\n{:?}", nodes.collect::<Vec<_>>());
}
}
8 changes: 8 additions & 0 deletions crates/medmodels-core/src/medrecord/querying/edges/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
mod operand;
mod operation;
mod selection;
mod values;

pub use operand::{EdgeOperand, EdgeValueOperand, EdgeValuesOperand};
pub(crate) use operation::EdgeValuesOperation;
pub use selection::EdgeSelection;
Loading

0 comments on commit 5df1278

Please sign in to comment.