From ef09e2e106a740ebf0f9e5c5b44b954a7e20ff44 Mon Sep 17 00:00:00 2001 From: Jakob Kraus <52459467+JabobKrauskopf@users.noreply.github.com> Date: Thu, 10 Oct 2024 11:12:19 +0200 Subject: [PATCH] refactor: rust query engine (#193) --- Cargo.lock | 31 +- Cargo.toml | 5 +- crates/medmodels-core/Cargo.toml | 4 +- crates/medmodels-core/src/errors/medrecord.rs | 3 + crates/medmodels-core/src/errors/mod.rs | 2 + .../src/medrecord/datatypes/attribute.rs | 223 ++- .../src/medrecord/datatypes/mod.rs | 72 +- .../src/medrecord/datatypes/value.rs | 201 +- .../src/medrecord/example_dataset/mod.rs | 11 +- .../src/medrecord/graph/edge.rs | 6 +- .../medmodels-core/src/medrecord/graph/mod.rs | 45 +- .../src/medrecord/graph/node.rs | 6 +- crates/medmodels-core/src/medrecord/mod.rs | 42 +- .../src/medrecord/querying/attributes/mod.rs | 132 ++ .../medrecord/querying/attributes/operand.rs | 874 +++++++++ .../querying/attributes/operation.rs | 1357 +++++++++++++ .../src/medrecord/querying/edges/mod.rs | 58 + .../src/medrecord/querying/edges/operand.rs | 655 +++++++ .../src/medrecord/querying/edges/operation.rs | 762 ++++++++ .../src/medrecord/querying/edges/selection.rs | 32 + .../src/medrecord/querying/mod.rs | 15 +- .../src/medrecord/querying/nodes/mod.rs | 68 + .../src/medrecord/querying/nodes/operand.rs | 732 +++++++ .../src/medrecord/querying/nodes/operation.rs | 971 +++++++++ .../src/medrecord/querying/nodes/selection.rs | 35 + .../querying/operation/edge_operation.rs | 475 ----- .../src/medrecord/querying/operation/mod.rs | 394 ---- .../querying/operation/node_operation.rs | 246 --- .../medrecord/querying/operation/operand.rs | 649 ------ .../src/medrecord/querying/selection.rs | 1741 ----------------- .../src/medrecord/querying/traits.rs | 21 + .../src/medrecord/querying/values/mod.rs | 185 ++ .../src/medrecord/querying/values/operand.rs | 590 ++++++ .../medrecord/querying/values/operation.rs | 934 +++++++++ .../src/medrecord/querying/wrapper.rs | 45 + 
crates/medmodels-core/src/medrecord/schema.rs | 2 - rustmodels/Cargo.toml | 4 +- rustmodels/src/medrecord/mod.rs | 2 +- 38 files changed, 7959 insertions(+), 3671 deletions(-) create mode 100644 crates/medmodels-core/src/medrecord/querying/attributes/mod.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/attributes/operand.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/attributes/operation.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/edges/mod.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/edges/operand.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/edges/operation.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/edges/selection.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/nodes/mod.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/nodes/operand.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/nodes/operation.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/nodes/selection.rs delete mode 100644 crates/medmodels-core/src/medrecord/querying/operation/edge_operation.rs delete mode 100644 crates/medmodels-core/src/medrecord/querying/operation/mod.rs delete mode 100644 crates/medmodels-core/src/medrecord/querying/operation/node_operation.rs delete mode 100644 crates/medmodels-core/src/medrecord/querying/operation/operand.rs delete mode 100644 crates/medmodels-core/src/medrecord/querying/selection.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/traits.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/values/mod.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/values/operand.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/values/operation.rs create mode 100644 crates/medmodels-core/src/medrecord/querying/wrapper.rs diff --git a/Cargo.lock b/Cargo.lock index 60095bba..280c2399 100644 --- a/Cargo.lock +++ 
b/Cargo.lock @@ -116,9 +116,9 @@ checksum = "d32a994c2b3ca201d9b263612a374263f05e7adde37c4707f693dcd375076d1f" [[package]] name = "bytemuck" -version = "1.14.3" +version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2ef034f05691a48569bd920a96c81b9d91bbad1ab5ac7c4616c1f6ef36cb79f" +checksum = "773d90827bc3feecfb67fab12e24de0749aad83c74b9504ecde46237b5cd24e2" dependencies = [ "bytemuck_derive", ] @@ -134,6 +134,12 @@ dependencies = [ "syn 2.0.50", ] +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.5.0" @@ -427,6 +433,15 @@ version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.10" @@ -517,8 +532,10 @@ name = "medmodels-core" version = "0.1.2" dependencies = [ "chrono", + "itertools", "medmodels-utils", "polars", + "roaring", "ron", "serde", ] @@ -1335,6 +1352,16 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +[[package]] +name = "roaring" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f4b84ba6e838ceb47b41de5194a60244fac43d9fe03b71dbe8c5a201081d6d1" +dependencies = [ + "bytemuck", + "byteorder", +] + [[package]] name = "ron" version = "0.8.1" diff --git a/Cargo.toml b/Cargo.toml index eac0a8c7..40154170 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,11 +13,8 @@ description = "Limebit MedModels 
Crate" [workspace.dependencies] hashbrown = { version = "0.14.5", features = ["serde"] } serde = { version = "1.0.203", features = ["derive"] } -ron = "0.8.1" -chrono = { version = "0.4.38", features = ["serde"] } -pyo3 = { version = "0.21.2", features = ["chrono"] } polars = { version = "0.40.0", features = ["polars-io"] } -pyo3-polars = "0.14.0" +chrono = { version = "0.4.38", features = ["serde"] } medmodels = { version = "0.1.2", path = "crates/medmodels" } medmodels-core = { version = "0.1.2", path = "crates/medmodels-core" } diff --git a/crates/medmodels-core/Cargo.toml b/crates/medmodels-core/Cargo.toml index 58097fcf..48225587 100644 --- a/crates/medmodels-core/Cargo.toml +++ b/crates/medmodels-core/Cargo.toml @@ -12,5 +12,7 @@ medmodels-utils = { workspace = true } polars = { workspace = true } serde = { workspace = true } -ron = { workspace = true } chrono = { workspace = true } +ron = "0.8.1" +roaring = "0.10.6" +itertools = "0.13.0" diff --git a/crates/medmodels-core/src/errors/medrecord.rs b/crates/medmodels-core/src/errors/medrecord.rs index f7afb230..3ad22a14 100644 --- a/crates/medmodels-core/src/errors/medrecord.rs +++ b/crates/medmodels-core/src/errors/medrecord.rs @@ -10,6 +10,7 @@ pub enum MedRecordError { ConversionError(String), AssertionError(String), SchemaError(String), + QueryError(String), } impl Error for MedRecordError { @@ -20,6 +21,7 @@ impl Error for MedRecordError { MedRecordError::ConversionError(message) => message, MedRecordError::AssertionError(message) => message, MedRecordError::SchemaError(message) => message, + MedRecordError::QueryError(message) => message, } } } @@ -32,6 +34,7 @@ impl Display for MedRecordError { Self::ConversionError(message) => write!(f, "ConversionError: {}", message), Self::AssertionError(message) => write!(f, "AssertionError: {}", message), Self::SchemaError(message) => write!(f, "SchemaError: {}", message), + Self::QueryError(message) => write!(f, "QueryError: {}", message), } } } diff --git 
a/crates/medmodels-core/src/errors/mod.rs b/crates/medmodels-core/src/errors/mod.rs index b0c37588..069281ca 100644 --- a/crates/medmodels-core/src/errors/mod.rs +++ b/crates/medmodels-core/src/errors/mod.rs @@ -14,6 +14,8 @@ impl From for MedRecordError { } } +pub type MedRecordResult = Result; + #[cfg(test)] mod test { use super::{GraphError, MedRecordError}; diff --git a/crates/medmodels-core/src/medrecord/datatypes/attribute.rs b/crates/medmodels-core/src/medrecord/datatypes/attribute.rs index f02f12d4..bdb2f12d 100644 --- a/crates/medmodels-core/src/medrecord/datatypes/attribute.rs +++ b/crates/medmodels-core/src/medrecord/datatypes/attribute.rs @@ -1,8 +1,16 @@ -use super::{Contains, EndsWith, MedRecordValue, StartsWith}; -use crate::errors::MedRecordError; +use super::{ + Abs, Contains, EndsWith, Lowercase, MedRecordValue, Mod, Pow, Slice, StartsWith, Trim, TrimEnd, + TrimStart, Uppercase, +}; +use crate::errors::{MedRecordError, MedRecordResult}; use medmodels_utils::implement_from_for_wrapper; use serde::{Deserialize, Serialize}; -use std::{cmp::Ordering, fmt::Display, hash::Hash}; +use std::{ + cmp::Ordering, + fmt::Display, + hash::Hash, + ops::{Add, Mul, Sub}, +}; #[derive(Debug, Clone, Serialize, Deserialize)] pub enum MedRecordAttribute { @@ -43,15 +51,6 @@ impl TryFrom for MedRecordAttribute { } } -impl Display for MedRecordAttribute { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::String(value) => write!(f, "{}", value), - Self::Int(value) => write!(f, "{}", value), - } - } -} - impl PartialEq for MedRecordAttribute { fn eq(&self, other: &Self) -> bool { match (self, other) { @@ -80,6 +79,140 @@ impl PartialOrd for MedRecordAttribute { } } +impl Display for MedRecordAttribute { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::String(value) => write!(f, "{}", value), + Self::Int(value) => write!(f, "{}", value), + } + } +} + +// TODO: Add tests +impl Add for 
MedRecordAttribute { + type Output = MedRecordResult; + + fn add(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => { + Ok(MedRecordAttribute::String(value + rhs.as_str())) + } + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot add {} to {}", rhs, value)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot add {} to {}", rhs, value)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value + rhs)) + } + } + } +} + +// TODO: Add tests +impl Sub for MedRecordAttribute { + type Output = MedRecordResult; + + fn sub(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), + ), + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value - rhs)) + } + } + } +} + +// TODO: Add tests +impl Mul for MedRecordAttribute { + type Output = MedRecordResult; + + fn mul(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot multiply {} by {}", value, rhs)), + ), + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot multiply {} by {}", value, rhs)), + ), + 
(MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot multiply {} by {}", value, rhs)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value * rhs)) + } + } + } +} + +// TODO: Add tests +impl Pow for MedRecordAttribute { + fn pow(self, rhs: Self) -> MedRecordResult { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => { + Err(MedRecordError::AssertionError(format!( + "Cannot raise {} to the power of {}", + value, rhs + ))) + } + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => { + Err(MedRecordError::AssertionError(format!( + "Cannot raise {} to the power of {}", + value, rhs + ))) + } + (MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => { + Err(MedRecordError::AssertionError(format!( + "Cannot raise {} to the power of {}", + value, rhs + ))) + } + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value.pow(rhs as u32))) + } + } + } +} + +// TODO: Add tests +impl Mod for MedRecordAttribute { + fn r#mod(self, rhs: Self) -> MedRecordResult { + match (self, rhs) { + (MedRecordAttribute::String(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot mod {} by {}", value, rhs)), + ), + (MedRecordAttribute::String(value), MedRecordAttribute::Int(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot mod {} by {}", value, rhs)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::String(rhs)) => Err( + MedRecordError::AssertionError(format!("Cannot mod {} by {}", value, rhs)), + ), + (MedRecordAttribute::Int(value), MedRecordAttribute::Int(rhs)) => { + Ok(MedRecordAttribute::Int(value % rhs)) + } + } + } +} + +// TODO: Add tests +impl Abs for MedRecordAttribute { + fn abs(self) -> Self { + match self { + MedRecordAttribute::Int(value) => 
MedRecordAttribute::Int(value.abs()), + _ => self, + } + } +} + impl StartsWith for MedRecordAttribute { fn starts_with(&self, other: &Self) -> bool { match (self, other) { @@ -137,6 +270,72 @@ impl Contains for MedRecordAttribute { } } +// TODO: Add tests +impl Slice for MedRecordAttribute { + fn slice(self, range: std::ops::Range) -> Self { + match self { + MedRecordAttribute::String(value) => value[range].into(), + MedRecordAttribute::Int(value) => value.to_string()[range].into(), + } + } +} + +// TODO: Add tests +impl Trim for MedRecordAttribute { + fn trim(self) -> Self { + match self { + MedRecordAttribute::String(value) => { + MedRecordAttribute::String(value.trim().to_string()) + } + _ => self, + } + } +} + +// TODO: Add tests +impl TrimStart for MedRecordAttribute { + fn trim_start(self) -> Self { + match self { + MedRecordAttribute::String(value) => { + MedRecordAttribute::String(value.trim_start().to_string()) + } + _ => self, + } + } +} + +// TODO: Add tests +impl TrimEnd for MedRecordAttribute { + fn trim_end(self) -> Self { + match self { + MedRecordAttribute::String(value) => { + MedRecordAttribute::String(value.trim_end().to_string()) + } + _ => self, + } + } +} + +// TODO: Add tests +impl Lowercase for MedRecordAttribute { + fn lowercase(self) -> Self { + match self { + MedRecordAttribute::String(value) => MedRecordAttribute::String(value.to_lowercase()), + _ => self, + } + } +} + +// TODO: Add tests +impl Uppercase for MedRecordAttribute { + fn uppercase(self) -> Self { + match self { + MedRecordAttribute::String(value) => MedRecordAttribute::String(value.to_uppercase()), + _ => self, + } + } +} + #[cfg(test)] mod test { use super::MedRecordAttribute; diff --git a/crates/medmodels-core/src/medrecord/datatypes/mod.rs b/crates/medmodels-core/src/medrecord/datatypes/mod.rs index 0beca37e..ada0f6c0 100644 --- a/crates/medmodels-core/src/medrecord/datatypes/mod.rs +++ b/crates/medmodels-core/src/medrecord/datatypes/mod.rs @@ -2,6 +2,7 @@ mod attribute; 
mod value; pub use self::{attribute::MedRecordAttribute, value::MedRecordValue}; +use super::EdgeIndex; use crate::errors::MedRecordError; use serde::{Deserialize, Serialize}; use std::{fmt::Display, ops::Range}; @@ -51,6 +52,24 @@ impl From<&MedRecordValue> for DataType { } } +impl From for DataType { + fn from(value: MedRecordAttribute) -> Self { + match value { + MedRecordAttribute::String(_) => DataType::String, + MedRecordAttribute::Int(_) => DataType::Int, + } + } +} + +impl From<&MedRecordAttribute> for DataType { + fn from(value: &MedRecordAttribute) -> Self { + match value { + MedRecordAttribute::String(_) => DataType::String, + MedRecordAttribute::Int(_) => DataType::Int, + } + } +} + impl PartialEq for DataType { fn eq(&self, other: &Self) -> bool { match (self, other) { @@ -126,28 +145,52 @@ impl DataType { } } -pub trait Pow: Sized { - fn pow(self, exp: Self) -> Result; -} - -pub trait Mod: Sized { - fn r#mod(self, other: Self) -> Result; -} - pub trait StartsWith { fn starts_with(&self, other: &Self) -> bool; } +// TODO: Add tests +impl StartsWith for EdgeIndex { + fn starts_with(&self, other: &Self) -> bool { + self.to_string().starts_with(&other.to_string()) + } +} + pub trait EndsWith { fn ends_with(&self, other: &Self) -> bool; } +// TODO: Add tests +impl EndsWith for EdgeIndex { + fn ends_with(&self, other: &Self) -> bool { + self.to_string().ends_with(&other.to_string()) + } +} + pub trait Contains { fn contains(&self, other: &Self) -> bool; } -pub trait PartialNeq: PartialEq { - fn neq(&self, other: &Self) -> bool; +// TODO: Add tests +impl Contains for EdgeIndex { + fn contains(&self, other: &Self) -> bool { + self.to_string().contains(&other.to_string()) + } +} + +pub trait Pow: Sized { + fn pow(self, exp: Self) -> Result; +} + +pub trait Mod: Sized { + fn r#mod(self, other: Self) -> Result; +} + +// TODO: Add tests +impl Mod for EdgeIndex { + fn r#mod(self, other: Self) -> Result { + Ok(self % other) + } } pub trait Round { @@ -194,15 +237,6 
@@ pub trait Slice { fn slice(self, range: Range) -> Self; } -impl PartialNeq for T -where - T: PartialOrd, -{ - fn neq(&self, other: &Self) -> bool { - self != other - } -} - #[cfg(test)] mod test { use super::{DataType, MedRecordValue}; diff --git a/crates/medmodels-core/src/medrecord/datatypes/value.rs b/crates/medmodels-core/src/medrecord/datatypes/value.rs index 792d879d..f3995102 100644 --- a/crates/medmodels-core/src/medrecord/datatypes/value.rs +++ b/crates/medmodels-core/src/medrecord/datatypes/value.rs @@ -3,7 +3,7 @@ use super::{ Trim, TrimEnd, TrimStart, Uppercase, }; use crate::errors::MedRecordError; -use chrono::NaiveDateTime; +use chrono::{DateTime, NaiveDateTime}; use medmodels_utils::implement_from_for_wrapper; use serde::{Deserialize, Serialize}; use std::{ @@ -210,9 +210,17 @@ impl Add for MedRecordValue { (MedRecordValue::DateTime(value), MedRecordValue::Bool(rhs)) => Err( MedRecordError::AssertionError(format!("Cannot add {} to {}", rhs, value)), ), - (MedRecordValue::DateTime(value), MedRecordValue::DateTime(rhs)) => Err( - MedRecordError::AssertionError(format!("Cannot add {} to {}", rhs, value)), - ), + (MedRecordValue::DateTime(value), MedRecordValue::DateTime(rhs)) => { + Ok(DateTime::from_timestamp( + value.and_utc().timestamp() + rhs.and_utc().timestamp(), + 0, + ) + .ok_or(MedRecordError::AssertionError( + "Invalid timestamp".to_string(), + ))? 
+ .naive_utc() + .into()) + } (MedRecordValue::DateTime(value), MedRecordValue::Null) => Err( MedRecordError::AssertionError(format!("Cannot add None to {}", value)), ), @@ -327,9 +335,17 @@ impl Sub for MedRecordValue { (MedRecordValue::DateTime(value), MedRecordValue::Bool(rhs)) => Err( MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), ), - (MedRecordValue::DateTime(value), MedRecordValue::DateTime(rhs)) => Err( - MedRecordError::AssertionError(format!("Cannot subtract {} from {}", rhs, value)), - ), + (MedRecordValue::DateTime(value), MedRecordValue::DateTime(rhs)) => { + Ok(DateTime::from_timestamp( + value.and_utc().timestamp() - rhs.and_utc().timestamp(), + 0, + ) + .ok_or(MedRecordError::AssertionError( + "Invalid timestamp".to_string(), + ))? + .naive_utc() + .into()) + } (MedRecordValue::DateTime(value), MedRecordValue::Null) => Err( MedRecordError::AssertionError(format!("Cannot subtract None from {}", value)), ), @@ -621,9 +637,17 @@ impl Div for MedRecordValue { (MedRecordValue::DateTime(value), MedRecordValue::String(other)) => Err( MedRecordError::AssertionError(format!("Cannot divide {} by {}", value, other)), ), - (MedRecordValue::DateTime(value), MedRecordValue::Int(other)) => Err( - MedRecordError::AssertionError(format!("Cannot divide {} by {}", value, other)), - ), + (MedRecordValue::DateTime(value), MedRecordValue::Int(other)) => { + Ok(DateTime::from_timestamp( + (value.and_utc().timestamp() as f64 / other as f64).floor() as i64, + 0, + ) + .ok_or(MedRecordError::AssertionError( + "Invalid timestamp".to_string(), + ))? 
+ .naive_utc() + .into()) + } (MedRecordValue::DateTime(value), MedRecordValue::Float(other)) => Err( MedRecordError::AssertionError(format!("Cannot divide {} by {}", value, other)), ), @@ -966,6 +990,53 @@ impl Mod for MedRecordValue { } } +impl Round for MedRecordValue { + fn round(self) -> Self { + match self { + MedRecordValue::Float(value) => MedRecordValue::Float(value.round()), + _ => self, + } + } +} + +impl Ceil for MedRecordValue { + fn ceil(self) -> Self { + match self { + MedRecordValue::Float(value) => MedRecordValue::Float(value.ceil()), + _ => self, + } + } +} + +impl Floor for MedRecordValue { + fn floor(self) -> Self { + match self { + MedRecordValue::Float(value) => MedRecordValue::Float(value.floor()), + _ => self, + } + } +} + +impl Abs for MedRecordValue { + fn abs(self) -> Self { + match self { + MedRecordValue::Int(value) => MedRecordValue::Int(value.abs()), + MedRecordValue::Float(value) => MedRecordValue::Float(value.abs()), + _ => self, + } + } +} + +impl Sqrt for MedRecordValue { + fn sqrt(self) -> Self { + match self { + MedRecordValue::Int(value) => MedRecordValue::Float((value as f64).sqrt()), + MedRecordValue::Float(value) => MedRecordValue::Float(value.sqrt()), + _ => self, + } + } +} + impl StartsWith for MedRecordValue { fn starts_with(&self, other: &Self) -> bool { match (self, other) { @@ -1081,53 +1152,6 @@ impl Slice for MedRecordValue { } } -impl Round for MedRecordValue { - fn round(self) -> Self { - match self { - MedRecordValue::Float(value) => MedRecordValue::Float(value.round()), - _ => self, - } - } -} - -impl Ceil for MedRecordValue { - fn ceil(self) -> Self { - match self { - MedRecordValue::Float(value) => MedRecordValue::Float(value.ceil()), - _ => self, - } - } -} - -impl Floor for MedRecordValue { - fn floor(self) -> Self { - match self { - MedRecordValue::Float(value) => MedRecordValue::Float(value.floor()), - _ => self, - } - } -} - -impl Abs for MedRecordValue { - fn abs(self) -> Self { - match self { - 
MedRecordValue::Int(value) => MedRecordValue::Int(value.abs()), - MedRecordValue::Float(value) => MedRecordValue::Float(value.abs()), - _ => self, - } - } -} - -impl Sqrt for MedRecordValue { - fn sqrt(self) -> Self { - match self { - MedRecordValue::Int(value) => MedRecordValue::Float((value as f64).sqrt()), - MedRecordValue::Float(value) => MedRecordValue::Float(value.sqrt()), - _ => self, - } - } -} - impl Trim for MedRecordValue { fn trim(self) -> Self { match self { @@ -1183,7 +1207,7 @@ mod test { Uppercase, }, }; - use chrono::NaiveDateTime; + use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime}; #[test] fn test_default() { @@ -1669,9 +1693,23 @@ mod test { (MedRecordValue::DateTime(NaiveDateTime::MIN) + MedRecordValue::Bool(false)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); - assert!((MedRecordValue::DateTime(NaiveDateTime::MIN) - + MedRecordValue::DateTime(NaiveDateTime::MIN)) - .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); + assert_eq!( + MedRecordValue::DateTime( + NaiveDate::from_ymd_opt(1970, 1, 4) + .unwrap() + .and_time(NaiveTime::MIN) + ), + (MedRecordValue::DateTime( + NaiveDate::from_ymd_opt(1970, 1, 2) + .unwrap() + .and_time(NaiveTime::MIN) + ) + MedRecordValue::DateTime( + NaiveDate::from_ymd_opt(1970, 1, 3) + .unwrap() + .and_time(NaiveTime::MIN) + )) + .unwrap() + ); assert!( (MedRecordValue::DateTime(NaiveDateTime::MIN) + MedRecordValue::Null) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) @@ -1794,9 +1832,12 @@ mod test { (MedRecordValue::DateTime(NaiveDateTime::MIN) - MedRecordValue::Bool(false)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); - assert!((MedRecordValue::DateTime(NaiveDateTime::MIN) - - MedRecordValue::DateTime(NaiveDateTime::MIN)) - .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); + assert_eq!( + MedRecordValue::DateTime(DateTime::from_timestamp(0, 0).unwrap().naive_utc()), + (MedRecordValue::DateTime(NaiveDateTime::MAX) 
+ - MedRecordValue::DateTime(NaiveDateTime::MAX)) + .unwrap() + ); assert!( (MedRecordValue::DateTime(NaiveDateTime::MIN) - MedRecordValue::Null) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) @@ -1951,15 +1992,15 @@ mod test { / MedRecordValue::String("value".to_string())) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( - (MedRecordValue::String("value".to_string()) / MedRecordValue::Int(0)) + (MedRecordValue::String("value".to_string()) / MedRecordValue::Int(1)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!( - (MedRecordValue::String("value".to_string()) / MedRecordValue::Float(0_f64)) + (MedRecordValue::String("value".to_string()) / MedRecordValue::Float(1_f64)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!( - (MedRecordValue::String("value".to_string()) / MedRecordValue::Bool(false)) + (MedRecordValue::String("value".to_string()) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!((MedRecordValue::String("value".to_string()) @@ -1982,7 +2023,7 @@ mod test { MedRecordValue::Float(1_f64), (MedRecordValue::Int(5) / MedRecordValue::Float(5_f64)).unwrap() ); - assert!((MedRecordValue::Int(0) / MedRecordValue::Bool(false)) + assert!((MedRecordValue::Int(0) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( (MedRecordValue::Int(0) / MedRecordValue::DateTime(NaiveDateTime::MIN)) @@ -2003,7 +2044,7 @@ mod test { MedRecordValue::Float(1_f64), (MedRecordValue::Float(5_f64) / MedRecordValue::Float(5_f64)).unwrap() ); - assert!((MedRecordValue::Float(0_f64) / MedRecordValue::Bool(false)) + assert!((MedRecordValue::Float(0_f64) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( (MedRecordValue::Float(0_f64) / MedRecordValue::DateTime(NaiveDateTime::MIN)) @@ -2016,11 +2057,11 @@ mod test { (MedRecordValue::Bool(false) 
/ MedRecordValue::String("value".to_string())) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); - assert!((MedRecordValue::Bool(false) / MedRecordValue::Int(0)) + assert!((MedRecordValue::Bool(false) / MedRecordValue::Int(1)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!((MedRecordValue::Bool(false) / MedRecordValue::Float(0_f64)) + assert!((MedRecordValue::Bool(false) / MedRecordValue::Float(1_f64)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!((MedRecordValue::Bool(false) / MedRecordValue::Bool(false)) + assert!((MedRecordValue::Bool(false) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( (MedRecordValue::Bool(false) / MedRecordValue::DateTime(NaiveDateTime::MIN)) @@ -2032,16 +2073,16 @@ mod test { assert!((MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::String("value".to_string())) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!( - (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Int(0)) - .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) + assert_eq!( + MedRecordValue::DateTime(NaiveDateTime::MIN), + (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Int(1)).unwrap() ); assert!( - (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Float(0_f64)) + (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Float(1_f64)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!( - (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Bool(false)) + (MedRecordValue::DateTime(NaiveDateTime::MIN) / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_))) ); assert!((MedRecordValue::DateTime(NaiveDateTime::MIN) @@ -2056,11 +2097,11 @@ mod test { (MedRecordValue::Null / MedRecordValue::String("value".to_string())) .is_err_and(|e| matches!(e, 
MedRecordError::AssertionError(_))) ); - assert!((MedRecordValue::Null / MedRecordValue::Int(0)) + assert!((MedRecordValue::Null / MedRecordValue::Int(1)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!((MedRecordValue::Null / MedRecordValue::Float(0_f64)) + assert!((MedRecordValue::Null / MedRecordValue::Float(1_f64)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); - assert!((MedRecordValue::Null / MedRecordValue::Bool(false)) + assert!((MedRecordValue::Null / MedRecordValue::Bool(true)) .is_err_and(|e| matches!(e, MedRecordError::AssertionError(_)))); assert!( (MedRecordValue::Null / MedRecordValue::DateTime(NaiveDateTime::MIN)) diff --git a/crates/medmodels-core/src/medrecord/example_dataset/mod.rs b/crates/medmodels-core/src/medrecord/example_dataset/mod.rs index e4879307..2a0f3354 100644 --- a/crates/medmodels-core/src/medrecord/example_dataset/mod.rs +++ b/crates/medmodels-core/src/medrecord/example_dataset/mod.rs @@ -71,7 +71,7 @@ impl MedRecord { .into_reader_with_file_handle(cursor) .finish() .expect("DataFrame can be built"); - let patient_diagnosis_ids = (0..patient_diagnosis.height()).collect::>(); + let patient_diagnosis_ids = (0..patient_diagnosis.height() as u32).collect::>(); let cursor = Cursor::new(PATIENT_DRUG); let patient_drug = CsvReadOptions::default() @@ -79,8 +79,8 @@ impl MedRecord { .into_reader_with_file_handle(cursor) .finish() .expect("DataFrame can be built"); - let patient_drug_ids = (patient_diagnosis.height() - ..patient_diagnosis.height() + patient_drug.height()) + let patient_drug_ids = (patient_diagnosis.height() as u32 + ..(patient_diagnosis.height() + patient_drug.height()) as u32) .collect::>(); let cursor = Cursor::new(PATIENT_PROCEDURE); @@ -89,8 +89,9 @@ impl MedRecord { .into_reader_with_file_handle(cursor) .finish() .expect("DataFrame can be built"); - let patient_procedure_ids = (patient_diagnosis.height() + patient_drug.height() - ..patient_diagnosis.height() + 
patient_drug.height() + patient_procedure.height()) + let patient_procedure_ids = ((patient_diagnosis.height() + patient_drug.height()) as u32 + ..(patient_diagnosis.height() + patient_drug.height() + patient_procedure.height()) + as u32) .collect::>(); let mut medrecord = Self::from_dataframes( diff --git a/crates/medmodels-core/src/medrecord/graph/edge.rs b/crates/medmodels-core/src/medrecord/graph/edge.rs index a45b6c4d..36b790d8 100644 --- a/crates/medmodels-core/src/medrecord/graph/edge.rs +++ b/crates/medmodels-core/src/medrecord/graph/edge.rs @@ -3,9 +3,9 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Edge { - pub attributes: Attributes, - pub(super) source_node_index: NodeIndex, - pub(super) target_node_index: NodeIndex, + pub(crate) attributes: Attributes, + pub(crate) source_node_index: NodeIndex, + pub(crate) target_node_index: NodeIndex, } impl Edge { diff --git a/crates/medmodels-core/src/medrecord/graph/mod.rs b/crates/medmodels-core/src/medrecord/graph/mod.rs index 885aa3cc..9d3ebb4f 100644 --- a/crates/medmodels-core/src/medrecord/graph/mod.rs +++ b/crates/medmodels-core/src/medrecord/graph/mod.rs @@ -9,18 +9,18 @@ use node::Node; use serde::{Deserialize, Serialize}; use std::{ collections::{HashMap, HashSet}, - sync::atomic::AtomicUsize, + sync::atomic::AtomicU32, }; pub type NodeIndex = MedRecordAttribute; -pub type EdgeIndex = usize; +pub type EdgeIndex = u32; pub type Attributes = HashMap; #[derive(Serialize, Deserialize, Debug)] pub(super) struct Graph { pub(crate) nodes: MrHashMap, pub(crate) edges: MrHashMap, - edge_index_counter: AtomicUsize, + edge_index_counter: AtomicU32, } #[allow(dead_code)] @@ -29,7 +29,7 @@ impl Graph { Self { nodes: MrHashMap::new(), edges: MrHashMap::new(), - edge_index_counter: AtomicUsize::new(0), + edge_index_counter: AtomicU32::new(0), } } @@ -37,7 +37,7 @@ impl Graph { Self { nodes: MrHashMap::with_capacity(node_capacity), edges: 
MrHashMap::with_capacity(edge_capacity), - edge_index_counter: AtomicUsize::new(0), + edge_index_counter: AtomicU32::new(0), } } @@ -45,13 +45,13 @@ impl Graph { self.nodes.clear(); self.edges.clear(); - self.edge_index_counter = AtomicUsize::new(0); + self.edge_index_counter = AtomicU32::new(0); } pub fn clear_edges(&mut self) { self.edges.clear(); - self.edge_index_counter = AtomicUsize::new(0); + self.edge_index_counter = AtomicU32::new(0); } pub fn node_count(&self) -> usize { @@ -338,7 +338,7 @@ impl Graph { self.edges.contains_key(edge_index) } - pub fn neighbors( + pub fn neighbors_outgoing( &self, node_index: &NodeIndex, ) -> Result, GraphError> { @@ -360,6 +360,29 @@ impl Graph { })) } + // TODO: Add tests + pub fn neighbors_incoming( + &self, + node_index: &NodeIndex, + ) -> Result, GraphError> { + Ok(self + .nodes + .get(node_index) + .ok_or(GraphError::IndexError(format!( + "Cannot find node with index {}", + node_index + )))? + .incoming_edge_indices + .iter() + .map(|edge_index| { + &self + .edges + .get(edge_index) + .expect("Edge must exist") + .source_node_index + })) + } + pub fn neighbors_undirected( &self, node_index: &NodeIndex, @@ -890,7 +913,7 @@ mod test { fn test_neighbors() { let graph = create_graph(); - let neighbors = graph.neighbors(&"0".into()).unwrap(); + let neighbors = graph.neighbors_outgoing(&"0".into()).unwrap(); assert_eq!(2, neighbors.count()); } @@ -900,7 +923,7 @@ mod test { let graph = create_graph(); assert!(graph - .neighbors(&"50".into()) + .neighbors_outgoing(&"50".into()) .is_err_and(|e| matches!(e, GraphError::IndexError(_)))); } @@ -908,7 +931,7 @@ mod test { fn test_neighbors_undirected() { let graph = create_graph(); - let neighbors = graph.neighbors(&"2".into()).unwrap(); + let neighbors = graph.neighbors_outgoing(&"2".into()).unwrap(); assert_eq!(0, neighbors.count()); let neighbors = graph.neighbors_undirected(&"2".into()).unwrap(); diff --git a/crates/medmodels-core/src/medrecord/graph/node.rs 
b/crates/medmodels-core/src/medrecord/graph/node.rs index 9af16851..4d90ee0f 100644 --- a/crates/medmodels-core/src/medrecord/graph/node.rs +++ b/crates/medmodels-core/src/medrecord/graph/node.rs @@ -4,9 +4,9 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Node { - pub attributes: Attributes, - pub(super) outgoing_edge_indices: MrHashSet, - pub(super) incoming_edge_indices: MrHashSet, + pub(crate) attributes: Attributes, + pub(crate) outgoing_edge_indices: MrHashSet, + pub(crate) incoming_edge_indices: MrHashSet, } impl Node { diff --git a/crates/medmodels-core/src/medrecord/mod.rs b/crates/medmodels-core/src/medrecord/mod.rs index 76af2add..aaea6808 100644 --- a/crates/medmodels-core/src/medrecord/mod.rs +++ b/crates/medmodels-core/src/medrecord/mod.rs @@ -11,9 +11,9 @@ pub use self::{ graph::{Attributes, EdgeIndex, NodeIndex}, group_mapping::Group, querying::{ - edge, node, ArithmeticOperation, EdgeAttributeOperand, EdgeIndexOperand, EdgeOperand, - EdgeOperation, NodeAttributeOperand, NodeIndexOperand, NodeOperand, NodeOperation, - TransformationOperation, ValueOperand, + edges::EdgeOperand, + nodes::NodeOperand, + wrapper::{CardinalityWrapper, Wrapper}, }, schema::{AttributeDataType, AttributeType, GroupSchema, Schema}, }; @@ -22,7 +22,7 @@ use ::polars::frame::DataFrame; use graph::Graph; use group_mapping::GroupMapping; use polars::{dataframe_to_edges, dataframe_to_nodes}; -use querying::{EdgeSelection, NodeSelection}; +use querying::{edges::EdgeSelection, nodes::NodeSelection}; use serde::{Deserialize, Serialize}; use std::{fs, mem, path::Path}; @@ -683,12 +683,22 @@ impl MedRecord { self.group_mapping.contains_group(group) } - pub fn neighbors( + pub fn neighbors_outgoing( &self, node_index: &NodeIndex, ) -> Result, MedRecordError> { self.graph - .neighbors(node_index) + .neighbors_outgoing(node_index) + .map_err(MedRecordError::from) + } + + // TODO: Add tests + pub fn neighbors_incoming( + &self, + 
node_index: &NodeIndex, + ) -> Result, MedRecordError> { + self.graph + .neighbors_incoming(node_index) .map_err(MedRecordError::from) } @@ -706,12 +716,18 @@ impl MedRecord { self.group_mapping.clear(); } - pub fn select_nodes(&self, operation: NodeOperation) -> NodeSelection { - NodeSelection::new(self, operation) + pub fn select_nodes(&self, query: Q) -> NodeSelection + where + Q: FnOnce(&mut Wrapper), + { + NodeSelection::new(self, query) } - pub fn select_edges(&self, operation: EdgeOperation) -> EdgeSelection { - EdgeSelection::new(self, operation) + pub fn select_edges(&self, query: Q) -> EdgeSelection + where + Q: FnOnce(&mut Wrapper), + { + EdgeSelection::new(self, query) } } @@ -1844,7 +1860,7 @@ mod test { fn test_neighbors() { let medrecord = create_medrecord(); - let neighbors = medrecord.neighbors(&"0".into()).unwrap(); + let neighbors = medrecord.neighbors_outgoing(&"0".into()).unwrap(); assert_eq!(2, neighbors.count()); } @@ -1855,7 +1871,7 @@ mod test { // Querying neighbors of a non-existing node sohuld fail assert!(medrecord - .neighbors(&"0".into()) + .neighbors_outgoing(&"0".into()) .is_err_and(|e| matches!(e, MedRecordError::IndexError(_)))); } @@ -1863,7 +1879,7 @@ mod test { fn test_neighbors_undirected() { let medrecord = create_medrecord(); - let neighbors = medrecord.neighbors(&"2".into()).unwrap(); + let neighbors = medrecord.neighbors_outgoing(&"2".into()).unwrap(); assert_eq!(0, neighbors.count()); let neighbors = medrecord.neighbors_undirected(&"2".into()).unwrap(); diff --git a/crates/medmodels-core/src/medrecord/querying/attributes/mod.rs b/crates/medmodels-core/src/medrecord/querying/attributes/mod.rs new file mode 100644 index 00000000..d16fcabd --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/attributes/mod.rs @@ -0,0 +1,132 @@ +mod operand; +mod operation; + +use super::{ + edges::{EdgeOperand, EdgeOperation}, + nodes::{NodeOperand, NodeOperation}, + BoxedIterator, +}; +use crate::{ + errors::MedRecordResult, + 
medrecord::{Attributes, EdgeIndex, MedRecordAttribute, NodeIndex}, + MedRecord, +}; +pub use operand::{AttributesTreeOperand, MultipleAttributesOperand}; +pub use operation::{AttributesTreeOperation, MultipleAttributesOperation}; +use std::fmt::Display; + +#[derive(Debug, Clone)] +pub enum SingleKind { + Max, + Min, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum MultipleKind { + Max, + Min, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum SingleComparisonKind { + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, + EqualTo, + NotEqualTo, + StartsWith, + EndsWith, + Contains, +} + +#[derive(Debug, Clone)] +pub enum MultipleComparisonKind { + IsIn, + IsNotIn, +} + +#[derive(Debug, Clone)] +pub enum BinaryArithmeticKind { + Add, + Sub, + Mul, + Pow, + Mod, +} + +impl Display for BinaryArithmeticKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BinaryArithmeticKind::Add => write!(f, "add"), + BinaryArithmeticKind::Sub => write!(f, "sub"), + BinaryArithmeticKind::Mul => write!(f, "mul"), + BinaryArithmeticKind::Pow => write!(f, "pow"), + BinaryArithmeticKind::Mod => write!(f, "mod"), + } + } +} + +#[derive(Debug, Clone)] +pub enum UnaryArithmeticKind { + Abs, + Trim, + TrimStart, + TrimEnd, + Lowercase, + Uppercase, +} + +pub(crate) trait GetAttributes { + fn get_attributes<'a>(&'a self, medrecord: &'a MedRecord) -> MedRecordResult<&'a Attributes>; +} + +impl GetAttributes for NodeIndex { + fn get_attributes<'a>(&'a self, medrecord: &'a MedRecord) -> MedRecordResult<&'a Attributes> { + medrecord.node_attributes(self) + } +} + +impl GetAttributes for EdgeIndex { + fn get_attributes<'a>(&'a self, medrecord: &'a MedRecord) -> MedRecordResult<&'a Attributes> { + medrecord.edge_attributes(self) + } +} + +#[derive(Debug, Clone)] +pub enum Context { + NodeOperand(NodeOperand), + EdgeOperand(EdgeOperand), +} + +impl Context { + pub(crate) fn get_attributes<'a>( + 
&self, + medrecord: &'a MedRecord, + ) -> MedRecordResult>> { + Ok(match self { + Self::NodeOperand(node_operand) => { + let node_indices = node_operand.evaluate(medrecord)?; + + Box::new( + NodeOperation::get_attributes(medrecord, node_indices).map(|(_, value)| value), + ) + } + Self::EdgeOperand(edge_operand) => { + let edge_indices = edge_operand.evaluate(medrecord)?; + + Box::new( + EdgeOperation::get_attributes(medrecord, edge_indices).map(|(_, value)| value), + ) + } + }) + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/attributes/operand.rs b/crates/medmodels-core/src/medrecord/querying/attributes/operand.rs new file mode 100644 index 00000000..83af4393 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/attributes/operand.rs @@ -0,0 +1,874 @@ +use super::{ + operation::{AttributesTreeOperation, MultipleAttributesOperation, SingleAttributeOperation}, + BinaryArithmeticKind, Context, GetAttributes, MultipleComparisonKind, MultipleKind, + SingleComparisonKind, SingleKind, UnaryArithmeticKind, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{ + querying::{ + traits::{DeepClone, ReadWriteOrPanic}, + values::{self, MultipleValuesOperand}, + BoxedIterator, + }, + MedRecordAttribute, Wrapper, + }, + MedRecord, +}; +use std::{fmt::Display, hash::Hash}; + +macro_rules! implement_attributes_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = Wrapper::::new( + self.deep_clone(), + MultipleKind::$variant, + ); + + self.operations + .push(AttributesTreeOperation::AttributesOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! 
implement_attribute_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = + Wrapper::::new(self.deep_clone(), SingleKind::$variant); + + self.operations + .push(MultipleAttributesOperation::AttributeOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! implement_single_attribute_comparison_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, attribute: V) { + self.operations + .push($operation::SingleAttributeComparisonOperation { + operand: attribute.into(), + kind: SingleComparisonKind::$kind, + }); + } + }; +} + +macro_rules! implement_binary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, attribute: V) { + self.operations.push($operation::BinaryArithmeticOpration { + operand: attribute.into(), + kind: BinaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_unary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name(&mut self) { + self.operations.push($operation::UnaryArithmeticOperation { + kind: UnaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_assertion_operation { + ($name:ident, $operation:expr) => { + pub fn $name(&mut self) { + self.operations.push($operation); + } + }; +} + +macro_rules! implement_wrapper_operand { + ($name:ident) => { + pub fn $name(&self) { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_return { + ($name:ident, $return_operand:ident) => { + pub fn $name(&self) -> Wrapper<$return_operand> { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! 
implement_wrapper_operand_with_argument { + ($name:ident, $attribute_type:ty) => { + pub fn $name(&self, attribute: $attribute_type) { + self.0.write_or_panic().$name(attribute) + } + }; +} + +#[derive(Debug, Clone)] +pub enum SingleAttributeComparisonOperand { + Operand(SingleAttributeOperand), + Attribute(MedRecordAttribute), +} + +impl DeepClone for SingleAttributeComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Attribute(attribute) => Self::Attribute(attribute.clone()), + } + } +} + +impl From> for SingleAttributeComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for SingleAttributeComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From for SingleAttributeComparisonOperand { + fn from(value: V) -> Self { + Self::Attribute(value.into()) + } +} + +#[derive(Debug, Clone)] +pub enum MultipleAttributesComparisonOperand { + Operand(MultipleAttributesOperand), + Attributes(Vec), +} + +impl DeepClone for MultipleAttributesComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Attributes(attribute) => Self::Attributes(attribute.clone()), + } + } +} + +impl From> for MultipleAttributesComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for MultipleAttributesComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From> for MultipleAttributesComparisonOperand { + fn from(value: Vec) -> Self { + Self::Attributes(value.into_iter().map(Into::into).collect()) + } +} + +impl + Clone, const N: usize> From<[V; N]> + for MultipleAttributesComparisonOperand +{ + fn from(value: [V; N]) -> 
Self { + value.to_vec().into() + } +} + +#[derive(Debug, Clone)] +pub struct AttributesTreeOperand { + pub(crate) context: Context, + operations: Vec, +} + +impl DeepClone for AttributesTreeOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl AttributesTreeOperand { + pub(crate) fn new(context: Context) -> Self { + Self { + context, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + ) -> MedRecordResult)>> { + let attributes = Box::new(attributes) as BoxedIterator<(&'a T, Vec)>; + + self.operations + .iter() + .try_fold(attributes, |attribute_tuples, operation| { + operation.evaluate(medrecord, attribute_tuples) + }) + } + + implement_attributes_operation!(max, Max); + implement_attributes_operation!(min, Min); + implement_attributes_operation!(count, Count); + implement_attributes_operation!(sum, Sum); + implement_attributes_operation!(first, First); + implement_attributes_operation!(last, Last); + + implement_single_attribute_comparison_operation!( + greater_than, + AttributesTreeOperation, + GreaterThan + ); + implement_single_attribute_comparison_operation!( + greater_than_or_equal_to, + AttributesTreeOperation, + GreaterThanOrEqualTo + ); + implement_single_attribute_comparison_operation!(less_than, AttributesTreeOperation, LessThan); + implement_single_attribute_comparison_operation!( + less_than_or_equal_to, + AttributesTreeOperation, + LessThanOrEqualTo + ); + implement_single_attribute_comparison_operation!(equal_to, AttributesTreeOperation, EqualTo); + implement_single_attribute_comparison_operation!( + not_equal_to, + AttributesTreeOperation, + NotEqualTo + ); + implement_single_attribute_comparison_operation!( + starts_with, + AttributesTreeOperation, + StartsWith + ); + 
implement_single_attribute_comparison_operation!(ends_with, AttributesTreeOperation, EndsWith); + implement_single_attribute_comparison_operation!(contains, AttributesTreeOperation, Contains); + + pub fn is_in>(&mut self, attributes: V) { + self.operations.push( + AttributesTreeOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsIn, + }, + ); + } + + pub fn is_not_in>(&mut self, attributes: V) { + self.operations.push( + AttributesTreeOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsNotIn, + }, + ); + } + + implement_binary_arithmetic_operation!(add, AttributesTreeOperation, Add); + implement_binary_arithmetic_operation!(sub, AttributesTreeOperation, Sub); + implement_binary_arithmetic_operation!(mul, AttributesTreeOperation, Mul); + implement_binary_arithmetic_operation!(pow, AttributesTreeOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, AttributesTreeOperation, Mod); + + implement_unary_arithmetic_operation!(abs, AttributesTreeOperation, Abs); + implement_unary_arithmetic_operation!(trim, AttributesTreeOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, AttributesTreeOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, AttributesTreeOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, AttributesTreeOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, AttributesTreeOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(AttributesTreeOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, AttributesTreeOperation::IsString); + implement_assertion_operation!(is_int, AttributesTreeOperation::IsInt); + implement_assertion_operation!(is_max, AttributesTreeOperation::IsMax); + implement_assertion_operation!(is_min, AttributesTreeOperation::IsMin); + + pub fn either_or(&mut self, 
either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(self.context.clone()); + let mut or_operand = Wrapper::::new(self.context.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(AttributesTreeOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: Context) -> Self { + AttributesTreeOperand::new(context).into() + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + ) -> MedRecordResult)>> { + self.0.read_or_panic().evaluate(medrecord, attributes) + } + + implement_wrapper_operand_with_return!(max, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(min, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(count, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(sum, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(first, MultipleAttributesOperand); + implement_wrapper_operand_with_return!(last, MultipleAttributesOperand); + + implement_wrapper_operand_with_argument!( + greater_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!( + not_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + starts_with, + impl Into + ); + implement_wrapper_operand_with_argument!( + ends_with, + impl Into + ); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + 
implement_wrapper_operand_with_argument!( + is_not_in, + impl Into + ); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +#[derive(Debug, Clone)] +pub struct MultipleAttributesOperand { + pub(crate) context: AttributesTreeOperand, + pub(crate) kind: MultipleKind, + operations: Vec, +} + +impl DeepClone for MultipleAttributesOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.clone(), + kind: self.kind.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl MultipleAttributesOperand { + pub(crate) fn new(context: AttributesTreeOperand, kind: MultipleKind) -> Self { + Self { + context, + kind, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + ) -> MedRecordResult> { + let attributes = Box::new(attributes) as BoxedIterator<(&'a T, MedRecordAttribute)>; + + self.operations + .iter() + .try_fold(attributes, |attribute_tuples, operation| { + 
operation.evaluate(medrecord, attribute_tuples) + }) + } + + implement_attribute_operation!(max, Max); + implement_attribute_operation!(min, Min); + implement_attribute_operation!(count, Count); + implement_attribute_operation!(sum, Sum); + implement_attribute_operation!(first, First); + implement_attribute_operation!(last, Last); + + implement_single_attribute_comparison_operation!( + greater_than, + MultipleAttributesOperation, + GreaterThan + ); + implement_single_attribute_comparison_operation!( + greater_than_or_equal_to, + MultipleAttributesOperation, + GreaterThanOrEqualTo + ); + implement_single_attribute_comparison_operation!( + less_than, + MultipleAttributesOperation, + LessThan + ); + implement_single_attribute_comparison_operation!( + less_than_or_equal_to, + MultipleAttributesOperation, + LessThanOrEqualTo + ); + implement_single_attribute_comparison_operation!( + equal_to, + MultipleAttributesOperation, + EqualTo + ); + implement_single_attribute_comparison_operation!( + not_equal_to, + MultipleAttributesOperation, + NotEqualTo + ); + implement_single_attribute_comparison_operation!( + starts_with, + MultipleAttributesOperation, + StartsWith + ); + implement_single_attribute_comparison_operation!( + ends_with, + MultipleAttributesOperation, + EndsWith + ); + implement_single_attribute_comparison_operation!( + contains, + MultipleAttributesOperation, + Contains + ); + + pub fn is_in>(&mut self, attributes: V) { + self.operations.push( + MultipleAttributesOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsIn, + }, + ); + } + + pub fn is_not_in>(&mut self, attributes: V) { + self.operations.push( + MultipleAttributesOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsNotIn, + }, + ); + } + + implement_binary_arithmetic_operation!(add, MultipleAttributesOperation, Add); + implement_binary_arithmetic_operation!(sub, 
MultipleAttributesOperation, Sub); + implement_binary_arithmetic_operation!(mul, MultipleAttributesOperation, Mul); + implement_binary_arithmetic_operation!(pow, MultipleAttributesOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, MultipleAttributesOperation, Mod); + + implement_unary_arithmetic_operation!(abs, MultipleAttributesOperation, Abs); + implement_unary_arithmetic_operation!(trim, MultipleAttributesOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, MultipleAttributesOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, MultipleAttributesOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, MultipleAttributesOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, MultipleAttributesOperation, Uppercase); + + #[allow(clippy::wrong_self_convention)] + pub fn to_values(&mut self) -> Wrapper { + let operand = Wrapper::::new( + values::Context::MultipleAttributesOperand(self.deep_clone()), + "unused".into(), + ); + + self.operations.push(MultipleAttributesOperation::ToValues { + operand: operand.clone(), + }); + + operand + } + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(MultipleAttributesOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, MultipleAttributesOperation::IsString); + implement_assertion_operation!(is_int, MultipleAttributesOperation::IsInt); + implement_assertion_operation!(is_max, MultipleAttributesOperation::IsMax); + implement_assertion_operation!(is_min, MultipleAttributesOperation::IsMin); + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + 
self.operations.push(MultipleAttributesOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: AttributesTreeOperand, kind: MultipleKind) -> Self { + MultipleAttributesOperand::new(context, kind).into() + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, attributes) + } + + implement_wrapper_operand_with_return!(max, SingleAttributeOperand); + implement_wrapper_operand_with_return!(min, SingleAttributeOperand); + implement_wrapper_operand_with_return!(count, SingleAttributeOperand); + implement_wrapper_operand_with_return!(sum, SingleAttributeOperand); + implement_wrapper_operand_with_return!(first, SingleAttributeOperand); + implement_wrapper_operand_with_return!(last, SingleAttributeOperand); + + implement_wrapper_operand_with_argument!( + greater_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!( + not_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + starts_with, + impl Into + ); + implement_wrapper_operand_with_argument!( + ends_with, + impl Into + ); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!( + is_not_in, + impl Into + ); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl 
Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + implement_wrapper_operand_with_return!(to_values, MultipleValuesOperand); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +#[derive(Debug, Clone)] +pub struct SingleAttributeOperand { + pub(crate) context: MultipleAttributesOperand, + pub(crate) kind: SingleKind, + operations: Vec, +} + +impl DeepClone for SingleAttributeOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.deep_clone(), + kind: self.kind.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl SingleAttributeOperand { + pub(crate) fn new(context: MultipleAttributesOperand, kind: SingleKind) -> Self { + Self { + context, + kind, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + attribute: MedRecordAttribute, + ) -> MedRecordResult> { + self.operations + .iter() + .try_fold(Some(attribute), |attribute, operation| { + if let Some(attribute) = attribute { + operation.evaluate(medrecord, attribute) + } else { + Ok(None) + } + }) + } + + implement_single_attribute_comparison_operation!( + greater_than, + SingleAttributeOperation, + GreaterThan + ); + implement_single_attribute_comparison_operation!( + greater_than_or_equal_to, + SingleAttributeOperation, + GreaterThanOrEqualTo + ); + 
implement_single_attribute_comparison_operation!(less_than, SingleAttributeOperation, LessThan); + implement_single_attribute_comparison_operation!( + less_than_or_equal_to, + SingleAttributeOperation, + LessThanOrEqualTo + ); + implement_single_attribute_comparison_operation!(equal_to, SingleAttributeOperation, EqualTo); + implement_single_attribute_comparison_operation!( + not_equal_to, + SingleAttributeOperation, + NotEqualTo + ); + implement_single_attribute_comparison_operation!( + starts_with, + SingleAttributeOperation, + StartsWith + ); + implement_single_attribute_comparison_operation!(ends_with, SingleAttributeOperation, EndsWith); + implement_single_attribute_comparison_operation!(contains, SingleAttributeOperation, Contains); + + pub fn is_in>(&mut self, attributes: V) { + self.operations.push( + SingleAttributeOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsIn, + }, + ); + } + + pub fn is_not_in>(&mut self, attributes: V) { + self.operations.push( + SingleAttributeOperation::MultipleAttributesComparisonOperation { + operand: attributes.into(), + kind: MultipleComparisonKind::IsNotIn, + }, + ); + } + + implement_binary_arithmetic_operation!(add, SingleAttributeOperation, Add); + implement_binary_arithmetic_operation!(sub, SingleAttributeOperation, Sub); + implement_binary_arithmetic_operation!(mul, SingleAttributeOperation, Mul); + implement_binary_arithmetic_operation!(pow, SingleAttributeOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, SingleAttributeOperation, Mod); + + implement_unary_arithmetic_operation!(abs, SingleAttributeOperation, Abs); + implement_unary_arithmetic_operation!(trim, SingleAttributeOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, SingleAttributeOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, SingleAttributeOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, SingleAttributeOperation, 
Lowercase); + implement_unary_arithmetic_operation!(uppercase, SingleAttributeOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(SingleAttributeOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, SingleAttributeOperation::IsString); + implement_assertion_operation!(is_int, SingleAttributeOperation::IsInt); + + pub fn eiter_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(SingleAttributeOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: MultipleAttributesOperand, kind: SingleKind) -> Self { + SingleAttributeOperand::new(context, kind).into() + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + attribute: MedRecordAttribute, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, attribute) + } + + implement_wrapper_operand_with_argument!( + greater_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than, + impl Into + ); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!( + not_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!( + starts_with, + impl Into + ); + implement_wrapper_operand_with_argument!( + ends_with, + impl Into + ); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + 
implement_wrapper_operand_with_argument!( + is_not_in, + impl Into + ); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().eiter_or(either_query, or_query); + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/attributes/operation.rs b/crates/medmodels-core/src/medrecord/querying/attributes/operation.rs new file mode 100644 index 00000000..71479dff --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/attributes/operation.rs @@ -0,0 +1,1357 @@ +use super::{ + operand::{ + MultipleAttributesComparisonOperand, MultipleAttributesOperand, + SingleAttributeComparisonOperand, SingleAttributeOperand, + }, + AttributesTreeOperand, BinaryArithmeticKind, GetAttributes, MultipleComparisonKind, + SingleComparisonKind, UnaryArithmeticKind, +}; +use crate::{ + errors::{MedRecordError, MedRecordResult}, + medrecord::{ + datatypes::{ + Abs, Contains, EndsWith, Lowercase, Mod, Pow, Slice, StartsWith, Trim, TrimEnd, + TrimStart, Uppercase, + }, + querying::{ + attributes::{MultipleKind, SingleKind}, + traits::{DeepClone, ReadWriteOrPanic}, + values::MultipleValuesOperand, + BoxedIterator, + }, + DataType, MedRecordAttribute, MedRecordValue, Wrapper, + 
}, + MedRecord, +}; +use itertools::Itertools; +use std::{ + cmp::Ordering, + collections::HashMap, + fmt::Display, + hash::Hash, + ops::{Add, Mul, Range, Sub}, +}; + +macro_rules! get_multiple_operand_attributes { + ($kind:ident, $attributes:expr) => { + match $kind { + MultipleKind::Max => Box::new(AttributesTreeOperation::get_max($attributes)?), + MultipleKind::Min => Box::new(AttributesTreeOperation::get_min($attributes)?), + MultipleKind::Count => Box::new(AttributesTreeOperation::get_count($attributes)?), + MultipleKind::Sum => Box::new(AttributesTreeOperation::get_sum($attributes)?), + MultipleKind::First => Box::new(AttributesTreeOperation::get_first($attributes)?), + MultipleKind::Last => Box::new(AttributesTreeOperation::get_last($attributes)?), + } + }; +} + +macro_rules! get_single_operand_attribute { + ($kind:ident, $attributes:expr) => { + match $kind { + SingleKind::Max => MultipleAttributesOperation::get_max($attributes)?.1, + SingleKind::Min => MultipleAttributesOperation::get_min($attributes)?.1, + SingleKind::Count => MultipleAttributesOperation::get_count($attributes), + SingleKind::Sum => MultipleAttributesOperation::get_sum($attributes)?, + SingleKind::First => MultipleAttributesOperation::get_first($attributes)?, + SingleKind::Last => MultipleAttributesOperation::get_last($attributes)?, + } + }; +} + +macro_rules! get_single_attribute_comparison_operand_attribute { + ($operand:ident, $medrecord:ident) => { + match $operand { + SingleAttributeComparisonOperand::Operand(operand) => { + let context = &operand.context.context.context; + let kind = &operand.context.kind; + + let comparison_attributes = context + .get_attributes($medrecord)? 
+ .map(|attribute| (&0, attribute)); + + let comparison_attributes: Box> = + get_multiple_operand_attributes!(kind, comparison_attributes); + + let kind = &operand.kind; + + get_single_operand_attribute!(kind, comparison_attributes) + } + SingleAttributeComparisonOperand::Attribute(attribute) => attribute.clone(), + } + }; +} + +#[derive(Debug, Clone)] +pub enum AttributesTreeOperation { + AttributesOperation { + operand: Wrapper, + }, + SingleAttributeComparisonOperation { + operand: SingleAttributeComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleAttributesComparisonOperation { + operand: MultipleAttributesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleAttributeComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for AttributesTreeOperation { + fn deep_clone(&self) -> Self { + match self { + Self::AttributesOperation { operand } => Self::AttributesOperation { + operand: operand.deep_clone(), + }, + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::SingleAttributeComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::MultipleAttributesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } 
=> Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl AttributesTreeOperation { + pub(crate) fn evaluate<'a, T: Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + ) -> MedRecordResult)>> { + match self { + Self::AttributesOperation { operand } => Ok(Box::new( + Self::evaluate_attributes_operation(medrecord, attributes, operand)?, + )), + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::evaluate_single_attribute_comparison_operation( + medrecord, attributes, operand, kind, + ) + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_attributes_comparison_operation( + medrecord, attributes, operand, kind, + ) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, attributes, operand, kind) + } + Self::UnaryArithmeticOperation { kind } => Ok(Box::new( + Self::evaluate_unary_arithmetic_operation(attributes, kind.clone()), + )), + Self::Slice(range) => Ok(Box::new(Self::evaluate_slice(attributes, range.clone()))), + Self::IsString => Ok(Box::new(attributes.map(|(index, attribute)| { + ( + index, + attribute + .into_iter() + .filter(|attribute| matches!(attribute, MedRecordAttribute::String(_))) + .collect(), + ) + }))), + Self::IsInt => Ok(Box::new(attributes.map(|(index, attribute)| { + ( + index, + attribute + .into_iter() + /* IsInt keeps only the integer attributes of each index */ .filter(|attribute| matches!(attribute, MedRecordAttribute::Int(_))) + .collect(), + ) + }))), + Self::IsMax => { + let max_attributes = Self::get_max(attributes)?; + + Ok(Box::new( + max_attributes.map(|(index, attribute)| (index, vec![attribute])), + )) + } + Self::IsMin => { + let min_attributes = Self::get_min(attributes)?; + + Ok(Box::new( + min_attributes.map(|(index, attribute)| (index, vec![attribute])), + )) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, attributes, either, 
or) + } + } + } + + #[inline] + pub(crate) fn get_max<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes.map(|(index, attributes)| { + let mut attributes = attributes.into_iter(); + + let first_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + let attribute = attributes.try_fold(first_attribute, |max, attribute| { + match attribute.partial_cmp(&max) { + Some(Ordering::Greater) => Ok(attribute), + None => { + let first_dtype = DataType::from(attribute); + let second_dtype = DataType::from(max); + + Err(MedRecordError::QueryError(format!( + "Cannot compare attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(max), + } + })?; + + Ok((index, attribute)) + }).collect::>>()?.into_iter()) + } + + #[inline] + pub(crate) fn get_min<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes.map(|(index, attributes)| { + let mut attributes = attributes.into_iter(); + + let first_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + let attribute = attributes.try_fold(first_attribute, |max, attribute| { + match attribute.partial_cmp(&max) { + Some(Ordering::Less) => Ok(attribute), + None => { + let first_dtype = DataType::from(attribute); + let second_dtype = DataType::from(max); + + Err(MedRecordError::QueryError(format!( + "Cannot compare attributes of data types {} and {}. 
Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(max), + } + })?; + + Ok((index, attribute)) + }).collect::>>()?.into_iter()) + } + + #[inline] + pub(crate) fn get_count<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes + .map(|(index, attribute)| (index, MedRecordAttribute::Int(attribute.len() as i64)))) + } + + #[inline] + pub(crate) fn get_sum<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes.map(|(index, attributes)| { + let mut attributes = attributes.into_iter(); + + let first_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + let attribute = attributes.try_fold(first_attribute, |sum, attribute| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&attribute); + + sum.add(attribute).map_err(|_| { + MedRecordError::QueryError(format!( + "Cannot add attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + )) + }) + })?; + + Ok((index, attribute)) + }).collect::>>()?.into_iter()) + } + + #[inline] + pub(crate) fn get_first<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes + .map(|(index, attributes)| { + let first_attribute = + attributes + .into_iter() + .next() + .ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + Ok((index, first_attribute)) + }) + .collect::>>()? + .into_iter()) + } + + #[inline] + pub(crate) fn get_last<'a, T: 'a>( + attributes: impl Iterator)>, + ) -> MedRecordResult> { + Ok(attributes + .map(|(index, attributes)| { + let first_attribute = + attributes + .into_iter() + .last() + .ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + Ok((index, first_attribute)) + }) + .collect::>>()? 
+ .into_iter()) + } + + #[inline] + fn evaluate_attributes_operation<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator)>, + operand: &Wrapper, + ) -> MedRecordResult)>> { + let kind = &operand.0.read_or_panic().kind; + + let attributes = attributes.collect::>(); + + let multiple_operand_attributes: Box> = + get_multiple_operand_attributes!(kind, attributes.clone().into_iter()); + + let result = operand.evaluate(medrecord, multiple_operand_attributes)?; + + let mut attributes = attributes.into_iter().collect::>(); + + Ok(result + .map(move |(index, _)| (index, attributes.remove(&index).expect("Index must exist")))) + } + + #[inline] + fn evaluate_single_attribute_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + comparison_operand: &SingleAttributeComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult)>> { + let comparison_attribute = + get_single_attribute_comparison_operand_attribute!(comparison_operand, medrecord); + + match kind { + SingleComparisonKind::GreaterThan => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute > &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::GreaterThanOrEqualTo => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute >= &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::LessThan => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute < &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::LessThanOrEqualTo => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute <= &comparison_attribute) + .collect(), + ) + }))) + } + 
SingleComparisonKind::EqualTo => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute == &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::NotEqualTo => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute != &comparison_attribute) + .collect(), + ) + }))) + } + SingleComparisonKind::StartsWith => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute.starts_with(&comparison_attribute)) + .collect(), + ) + }))) + } + SingleComparisonKind::EndsWith => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute.ends_with(&comparison_attribute)) + .collect(), + ) + }))) + } + SingleComparisonKind::Contains => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| attribute.contains(&comparison_attribute)) + .collect(), + ) + }))) + } + } + } + + #[inline] + fn evaluate_multiple_attributes_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + comparison_operand: &MultipleAttributesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult)>> { + let comparison_attributes = match comparison_operand { + MultipleAttributesComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + let comparison_attributes = context + .get_attributes(medrecord)? 
+ .map(|attribute| (&0, attribute)); + + let comparison_attributes: Box> = + get_multiple_operand_attributes!(kind, comparison_attributes); + + comparison_attributes + .map(|(_, attribute)| attribute) + .collect::>() + } + MultipleAttributesComparisonOperand::Attributes(attributes) => attributes.clone(), + }; + + match kind { + MultipleComparisonKind::IsIn => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| comparison_attributes.contains(attribute)) + .collect(), + ) + }))) + } + MultipleComparisonKind::IsNotIn => { + Ok(Box::new(attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .filter(|attribute| !comparison_attributes.contains(attribute)) + .collect(), + ) + }))) + } + } + } + + #[inline] + fn evaluate_binary_arithmetic_operation<'a, T: 'a>( + medrecord: &'a MedRecord, + attributes: impl Iterator)> + 'a, + operand: &SingleAttributeComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult)>> { + let arithmetic_attribute = + get_single_attribute_comparison_operand_attribute!(operand, medrecord); + + let attributes: Box< + dyn Iterator)>>, + > = match kind { + BinaryArithmeticKind::Add => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + .map(|attribute| attribute.add(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + BinaryArithmeticKind::Sub => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + .map(|attribute| attribute.sub(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + BinaryArithmeticKind::Mul => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + .map(|attribute| attribute.mul(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + BinaryArithmeticKind::Pow => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + 
.map(|attribute| attribute.pow(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + BinaryArithmeticKind::Mod => Box::new(attributes.map(move |(index, attributes)| { + Ok(( + index, + attributes + .into_iter() + .map(|attribute| attribute.r#mod(arithmetic_attribute.clone())) + .collect::>>()?, + )) + })), + }; + + Ok(Box::new( + attributes.collect::>>()?.into_iter(), + )) + } + + #[inline] + fn evaluate_unary_arithmetic_operation<'a, T: 'a>( + attributes: impl Iterator)>, + kind: UnaryArithmeticKind, + ) -> impl Iterator)> { + attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .map(|attribute| match kind { + UnaryArithmeticKind::Abs => attribute.abs(), + UnaryArithmeticKind::Trim => attribute.trim(), + UnaryArithmeticKind::TrimStart => attribute.trim_start(), + UnaryArithmeticKind::TrimEnd => attribute.trim_end(), + UnaryArithmeticKind::Lowercase => attribute.lowercase(), + UnaryArithmeticKind::Uppercase => attribute.uppercase(), + }) + .collect(), + ) + }) + } + + #[inline] + fn evaluate_slice<'a, T: 'a>( + attributes: impl Iterator)>, + range: Range, + ) -> impl Iterator)> { + attributes.map(move |(index, attributes)| { + ( + index, + attributes + .into_iter() + .map(|attribute| attribute.slice(range.clone())) + .collect(), + ) + }) + } + + #[inline] + fn evaluate_either_or<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator)>, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult)>> { + let attributes = attributes.collect::>(); + + let either_attributes = either.evaluate(medrecord, attributes.clone().into_iter())?; + let or_attributes = or.evaluate(medrecord, attributes.into_iter())?; + + Ok(Box::new( + either_attributes + .chain(or_attributes) + .unique_by(|attribute| attribute.0), + )) + } +} + +#[derive(Debug, Clone)] +pub enum MultipleAttributesOperation { + AttributeOperation { + operand: Wrapper, + }, + SingleAttributeComparisonOperation { + operand: 
SingleAttributeComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleAttributesComparisonOperation { + operand: MultipleAttributesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleAttributeComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + ToValues { + operand: Wrapper, + }, + + Slice(Range), + + IsString, + IsInt, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for MultipleAttributesOperation { + fn deep_clone(&self) -> Self { + match self { + Self::AttributeOperation { operand } => Self::AttributeOperation { + operand: operand.deep_clone(), + }, + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::SingleAttributeComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::MultipleAttributesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::ToValues { operand } => Self::ToValues { + operand: operand.deep_clone(), + }, + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl MultipleAttributesOperation { + pub(crate) fn evaluate<'a, T: Eq + Hash + GetAttributes + Display>( + &self, + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + ) -> MedRecordResult> { + match self { + Self::AttributeOperation { operand } => { + 
Self::evaluate_attribute_operation(medrecord, attributes, operand) + } + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::evaluate_single_attribute_comparison_operation( + medrecord, attributes, operand, kind, + ) + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_attributes_comparison_operation( + medrecord, attributes, operand, kind, + ) + } + Self::BinaryArithmeticOpration { operand, kind } => Ok(Box::new( + Self::evaluate_binary_arithmetic_operation(medrecord, attributes, operand, kind)?, + )), + Self::UnaryArithmeticOperation { kind } => Ok(Box::new( + Self::evaluate_unary_arithmetic_operation(attributes, kind.clone()), + )), + Self::ToValues { operand } => Ok(Box::new(Self::evaluate_to_values( + medrecord, attributes, operand, + )?)), + Self::Slice(range) => Ok(Box::new(Self::evaluate_slice(attributes, range.clone()))), + Self::IsString => { + Ok(Box::new(attributes.filter(|(_, attribute)| { + matches!(attribute, MedRecordAttribute::String(_)) + }))) + } + Self::IsInt => { + Ok(Box::new(attributes.filter(|(_, attribute)| { + matches!(attribute, MedRecordAttribute::Int(_)) + }))) + } + Self::IsMax => { + let max_attribute = Self::get_max(attributes)?; + + Ok(Box::new(std::iter::once(max_attribute))) + } + Self::IsMin => { + let min_attribute = Self::get_min(attributes)?; + + Ok(Box::new(std::iter::once(min_attribute))) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, attributes, either, or) + } + } + } + + #[inline] + pub(crate) fn get_max<'a, T>( + mut attributes: impl Iterator, + ) -> MedRecordResult<(&'a T, MedRecordAttribute)> { + let max_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + attributes.try_fold(max_attribute, |max_attribute, attribute| { + match attribute.1.partial_cmp(&max_attribute.1) { + Some(Ordering::Greater) => Ok(attribute), + None => { + let first_dtype = 
DataType::from(attribute.1); + let second_dtype = DataType::from(max_attribute.1); + + Err(MedRecordError::QueryError(format!( + "Cannot compare attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(max_attribute), + } + }) + } + + #[inline] + pub(crate) fn get_min<'a, T>( + mut attributes: impl Iterator, + ) -> MedRecordResult<(&'a T, MedRecordAttribute)> { + let min_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + attributes.try_fold(min_attribute, |min_attribute, attribute| { + match attribute.1.partial_cmp(&min_attribute.1) { + Some(Ordering::Less) => Ok(attribute), + None => { + let first_dtype = DataType::from(attribute.1); + let second_dtype = DataType::from(min_attribute.1); + + Err(MedRecordError::QueryError(format!( + "Cannot compare attributes of data types {} and {}. Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(min_attribute), + } + }) + } + + #[inline] + pub(crate) fn get_count<'a, T: 'a>( + attributes: impl Iterator, + ) -> MedRecordAttribute { + MedRecordAttribute::Int(attributes.count() as i64) + } + + #[inline] + // 🥊💥 + pub(crate) fn get_sum<'a, T: 'a>( + mut attributes: impl Iterator, + ) -> MedRecordResult { + let first_attribute = attributes.next().ok_or(MedRecordError::QueryError( + "No attributes to compare".to_string(), + ))?; + + attributes.try_fold(first_attribute.1, |sum, (_, attribute)| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&attribute); + + sum.add(attribute).map_err(|_| { + MedRecordError::QueryError(format!( + "Cannot add attributes of data types {} and {}. 
Consider narrowing down the attributes using .is_string() or .is_int()", + first_dtype, second_dtype + )) + }) + }) + } + + #[inline] + pub(crate) fn get_first<'a, T: 'a>( + mut attributes: impl Iterator, + ) -> MedRecordResult { + attributes + .next() + .ok_or(MedRecordError::QueryError( + "No attributes to get the first".to_string(), + )) + .map(|(_, attribute)| attribute) + } + + #[inline] + pub(crate) fn get_last<'a, T: 'a>( + attributes: impl Iterator, + ) -> MedRecordResult { + attributes + .last() + .ok_or(MedRecordError::QueryError( + "No attributes to get the first".to_string(), + )) + .map(|(_, attribute)| attribute) + } + + #[inline] + fn evaluate_attribute_operation<'a, T>( + medrecord: &'a MedRecord, + attribtues: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let kind = &operand.0.read_or_panic().kind; + + let attributes = attribtues.collect::>(); + + let attribute = get_single_operand_attribute!(kind, attributes.clone().into_iter()); + + Ok(match operand.evaluate(medrecord, attribute)? 
{ + Some(_) => Box::new(attributes.into_iter()), + None => Box::new(std::iter::empty()), + }) + } + + #[inline] + fn evaluate_single_attribute_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + comparison_operand: &SingleAttributeComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_attribute = + get_single_attribute_comparison_operand_attribute!(comparison_operand, medrecord); + + match kind { + SingleComparisonKind::GreaterThan => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute > &comparison_attribute + }))) + } + SingleComparisonKind::GreaterThanOrEqualTo => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute >= &comparison_attribute + }))) + } + SingleComparisonKind::LessThan => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute < &comparison_attribute + }))) + } + SingleComparisonKind::LessThanOrEqualTo => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute <= &comparison_attribute + }))) + } + SingleComparisonKind::EqualTo => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute == &comparison_attribute + }))) + } + SingleComparisonKind::NotEqualTo => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute != &comparison_attribute + }))) + } + SingleComparisonKind::StartsWith => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute.starts_with(&comparison_attribute) + }))) + } + SingleComparisonKind::EndsWith => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute.ends_with(&comparison_attribute) + }))) + } + SingleComparisonKind::Contains => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + attribute.contains(&comparison_attribute) + }))) + } + } + } + + #[inline] + fn evaluate_multiple_attributes_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + attributes: impl Iterator + 'a, + comparison_operand: 
&MultipleAttributesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_attributes = match comparison_operand { + MultipleAttributesComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + let comparison_attributes = context + .get_attributes(medrecord)? + .map(|attribute| (&0, attribute)); + + let attributes: Box> = + get_multiple_operand_attributes!(kind, comparison_attributes); + + attributes + .map(|(_, attribute)| attribute) + .collect::>() + } + MultipleAttributesComparisonOperand::Attributes(attributes) => attributes.clone(), + }; + + match kind { + MultipleComparisonKind::IsIn => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + comparison_attributes.contains(attribute) + }))) + } + MultipleComparisonKind::IsNotIn => { + Ok(Box::new(attributes.filter(move |(_, attribute)| { + !comparison_attributes.contains(attribute) + }))) + } + } + } + + #[inline] + fn evaluate_binary_arithmetic_operation<'a, T: 'a>( + medrecord: &'a MedRecord, + attributes: impl Iterator, + operand: &SingleAttributeComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_attribute = + get_single_attribute_comparison_operand_attribute!(operand, medrecord); + + let attributes = attributes + .map(move |(t, attribute)| { + match kind { + BinaryArithmeticKind::Add => attribute.add(arithmetic_attribute.clone()), + BinaryArithmeticKind::Sub => attribute.sub(arithmetic_attribute.clone()), + BinaryArithmeticKind::Mul => { + attribute.clone().mul(arithmetic_attribute.clone()) + } + BinaryArithmeticKind::Pow => { + attribute.clone().pow(arithmetic_attribute.clone()) + } + BinaryArithmeticKind::Mod => { + attribute.clone().r#mod(arithmetic_attribute.clone()) + } + } + .map_err(|_| { + MedRecordError::QueryError(format!( + "Failed arithmetic operation {}. 
Consider narrowing down the attributes using .is_string() or .is_int()", + kind, + )) + }).map(|result| (t, result)) + }); + + // TODO: This is a temporary solution. It should be optimized. + Ok(attributes.collect::>>()?.into_iter()) + } + + #[inline] + fn evaluate_unary_arithmetic_operation<'a, T: 'a>( + attributes: impl Iterator, + kind: UnaryArithmeticKind, + ) -> impl Iterator { + attributes.map(move |(t, attribute)| { + let attribute = match kind { + UnaryArithmeticKind::Abs => attribute.abs(), + UnaryArithmeticKind::Trim => attribute.trim(), + UnaryArithmeticKind::TrimStart => attribute.trim_start(), + UnaryArithmeticKind::TrimEnd => attribute.trim_end(), + UnaryArithmeticKind::Lowercase => attribute.lowercase(), + UnaryArithmeticKind::Uppercase => attribute.uppercase(), + }; + (t, attribute) + }) + } + + /* Resolves each (index, attribute) pair to a clone of the value stored under that attribute for that index; fails with a QueryError when the attribute is missing. */ pub(crate) fn get_values<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator, + ) -> MedRecordResult> { + Ok(attributes + .map(|(index, attribute)| { + let value = index.get_attributes(medrecord)?.get(&attribute).ok_or( + MedRecordError::QueryError(format!( + "Cannot find attribute {} for index {}", + attribute, index + )), + )?; + + Ok((index, value.clone())) + }) + .collect::>>()? 
+ .into_iter()) + } + + #[inline] + fn evaluate_to_values<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let attributes = attributes.collect::>(); + + let values = Self::get_values(medrecord, attributes.clone().into_iter())?; + + let mut attributes = attributes.into_iter().collect::>(); + + let values = operand.evaluate(medrecord, values.into_iter())?; + + Ok(values.map(move |(index, _)| { + ( + index, + attributes.remove(&index).expect("Attribute must exist"), + ) + })) + } + + #[inline] + fn evaluate_slice<'a, T: 'a>( + attributes: impl Iterator, + range: Range, + ) -> impl Iterator { + attributes.map(move |(t, attribute)| (t, attribute.slice(range.clone()))) + } + + #[inline] + fn evaluate_either_or<'a, T: 'a + Eq + Hash + GetAttributes + Display>( + medrecord: &'a MedRecord, + attributes: impl Iterator, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let attributes = attributes.collect::>(); + + let either_attributes = either.evaluate(medrecord, attributes.clone().into_iter())?; + let or_attributes = or.evaluate(medrecord, attributes.into_iter())?; + + Ok(Box::new( + either_attributes + .chain(or_attributes) + .unique_by(|attribute| attribute.0), + )) + } +} + +#[derive(Debug, Clone)] +pub enum SingleAttributeOperation { + SingleAttributeComparisonOperation { + operand: SingleAttributeComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleAttributesComparisonOperation { + operand: MultipleAttributesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleAttributeComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for SingleAttributeOperation { + fn deep_clone(&self) -> Self { + match self { + 
Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::SingleAttributeComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::MultipleAttributesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl SingleAttributeOperation { + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + attribute: MedRecordAttribute, + ) -> MedRecordResult> { + match self { + Self::SingleAttributeComparisonOperation { operand, kind } => { + Self::evaluate_single_attribute_comparison_operation( + medrecord, attribute, operand, kind, + ) + } + Self::MultipleAttributesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_attribute_comparison_operation( + medrecord, attribute, operand, kind, + ) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, attribute, operand, kind) + } + Self::UnaryArithmeticOperation { kind } => Ok(Some(match kind { + UnaryArithmeticKind::Abs => attribute.abs(), + UnaryArithmeticKind::Trim => attribute.trim(), + UnaryArithmeticKind::TrimStart => attribute.trim_start(), + UnaryArithmeticKind::TrimEnd => attribute.trim_end(), + UnaryArithmeticKind::Lowercase => attribute.lowercase(), + UnaryArithmeticKind::Uppercase => attribute.uppercase(), + })), + Self::Slice(range) => Ok(Some(attribute.slice(range.clone()))), + Self::IsString => Ok(match 
attribute { + MedRecordAttribute::String(_) => Some(attribute), + _ => None, + }), + Self::IsInt => Ok(match attribute { + MedRecordAttribute::Int(_) => Some(attribute), + _ => None, + }), + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, attribute, either, or) + } + } + } + + #[inline] + fn evaluate_single_attribute_comparison_operation( + medrecord: &MedRecord, + attribute: MedRecordAttribute, + comparison_operand: &SingleAttributeComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_attribute = + get_single_attribute_comparison_operand_attribute!(comparison_operand, medrecord); + + let comparison_result = match kind { + SingleComparisonKind::GreaterThan => attribute > comparison_attribute, + SingleComparisonKind::GreaterThanOrEqualTo => attribute >= comparison_attribute, + SingleComparisonKind::LessThan => attribute < comparison_attribute, + SingleComparisonKind::LessThanOrEqualTo => attribute <= comparison_attribute, + SingleComparisonKind::EqualTo => attribute == comparison_attribute, + SingleComparisonKind::NotEqualTo => attribute != comparison_attribute, + SingleComparisonKind::StartsWith => attribute.starts_with(&comparison_attribute), + SingleComparisonKind::EndsWith => attribute.ends_with(&comparison_attribute), + SingleComparisonKind::Contains => attribute.contains(&comparison_attribute), + }; + + Ok(if comparison_result { + Some(attribute) + } else { + None + }) + } + + #[inline] + fn evaluate_multiple_attribute_comparison_operation( + medrecord: &MedRecord, + attribute: MedRecordAttribute, + comparison_operand: &MultipleAttributesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_attributes = match comparison_operand { + MultipleAttributesComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + let comparison_attributes = context + .get_attributes(medrecord)? 
+ .map(|attribute| (&0, attribute)); + + let attributes: Box> = + get_multiple_operand_attributes!(kind, comparison_attributes); + + attributes + .map(|(_, attribute)| attribute) + .collect::>() + } + MultipleAttributesComparisonOperand::Attributes(attributes) => attributes.clone(), + }; + + let comparison_result = match kind { + MultipleComparisonKind::IsIn => comparison_attributes.contains(&attribute), + MultipleComparisonKind::IsNotIn => !comparison_attributes.contains(&attribute), + }; + + Ok(if comparison_result { + Some(attribute) + } else { + None + }) + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + attribute: MedRecordAttribute, + operand: &SingleAttributeComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_attribute = + get_single_attribute_comparison_operand_attribute!(operand, medrecord); + + match kind { + BinaryArithmeticKind::Add => attribute.add(arithmetic_attribute), + BinaryArithmeticKind::Sub => attribute.sub(arithmetic_attribute), + BinaryArithmeticKind::Mul => attribute.mul(arithmetic_attribute), + BinaryArithmeticKind::Pow => attribute.pow(arithmetic_attribute), + BinaryArithmeticKind::Mod => attribute.r#mod(arithmetic_attribute), + } + .map(Some) + } + + #[inline] + fn evaluate_either_or( + medrecord: &MedRecord, + attribute: MedRecordAttribute, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let either_result = either.evaluate(medrecord, attribute.clone())?; + let or_result = or.evaluate(medrecord, attribute)?; + + match (either_result, or_result) { + (Some(either_result), _) => Ok(Some(either_result)), + (None, Some(or_result)) => Ok(Some(or_result)), + _ => Ok(None), + } + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/edges/mod.rs b/crates/medmodels-core/src/medrecord/querying/edges/mod.rs new file mode 100644 index 00000000..1045e83e --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/edges/mod.rs @@ -0,0 +1,58 @@ 
+mod operand; +mod operation; +mod selection; + +pub use operand::EdgeOperand; +pub use operation::EdgeOperation; +pub use selection::EdgeSelection; +use std::fmt::Display; + +#[derive(Debug, Clone)] +pub enum SingleKind { + Max, + Min, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum SingleComparisonKind { + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, + EqualTo, + NotEqualTo, + StartsWith, + EndsWith, + Contains, +} + +#[derive(Debug, Clone)] +pub enum MultipleComparisonKind { + IsIn, + IsNotIn, +} + +#[derive(Debug, Clone)] +pub enum BinaryArithmeticKind { + Add, + Sub, + Mul, + Pow, + Mod, +} + +impl Display for BinaryArithmeticKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BinaryArithmeticKind::Add => write!(f, "add"), + BinaryArithmeticKind::Sub => write!(f, "sub"), + BinaryArithmeticKind::Mul => write!(f, "mul"), + BinaryArithmeticKind::Pow => write!(f, "pow"), + BinaryArithmeticKind::Mod => write!(f, "mod"), + } + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/edges/operand.rs b/crates/medmodels-core/src/medrecord/querying/edges/operand.rs new file mode 100644 index 00000000..4b7b4f85 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/edges/operand.rs @@ -0,0 +1,655 @@ +use super::{ + operation::{EdgeIndexOperation, EdgeIndicesOperation, EdgeOperation}, + BinaryArithmeticKind, MultipleComparisonKind, SingleComparisonKind, SingleKind, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{ + querying::{ + attributes::{self, AttributesTreeOperand}, + nodes::NodeOperand, + traits::{DeepClone, ReadWriteOrPanic}, + values::{self, MultipleValuesOperand}, + wrapper::Wrapper, + BoxedIterator, + }, + CardinalityWrapper, EdgeIndex, Group, MedRecordAttribute, + }, + MedRecord, +}; +use std::fmt::Debug; + +#[derive(Debug, Clone)] +pub struct EdgeOperand { + pub(crate) operations: Vec, +} + +impl DeepClone for EdgeOperand { + fn 
deep_clone(&self) -> Self { + Self { + operations: self + .operations + .iter() + .map(|operation| operation.deep_clone()) + .collect(), + } + } +} + +impl EdgeOperand { + pub(crate) fn new() -> Self { + Self { + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + ) -> MedRecordResult> { + let edge_indices = Box::new(medrecord.edge_indices()) as BoxedIterator<&'a EdgeIndex>; + + self.operations + .iter() + .try_fold(edge_indices, |edge_indices, operation| { + operation.evaluate(medrecord, edge_indices) + }) + } + + pub fn attribute(&mut self, attribute: MedRecordAttribute) -> Wrapper { + let operand = Wrapper::::new( + values::Context::EdgeOperand(self.deep_clone()), + attribute, + ); + + self.operations.push(EdgeOperation::Values { + operand: operand.clone(), + }); + + operand + } + + pub fn attributes(&mut self) -> Wrapper { + let operand = Wrapper::::new(attributes::Context::EdgeOperand( + self.deep_clone(), + )); + + self.operations.push(EdgeOperation::Attributes { + operand: operand.clone(), + }); + + operand + } + + pub fn index(&mut self) -> Wrapper { + let operand = Wrapper::::new(self.deep_clone()); + + self.operations.push(EdgeOperation::Indices { + operand: operand.clone(), + }); + + operand + } + + pub fn in_group(&mut self, group: G) + where + G: Into>, + { + self.operations.push(EdgeOperation::InGroup { + group: group.into(), + }); + } + + pub fn has_attribute(&mut self, attribute: A) + where + A: Into>, + { + self.operations.push(EdgeOperation::HasAttribute { + attribute: attribute.into(), + }); + } + + pub fn source_node(&mut self) -> Wrapper { + let operand = Wrapper::::new(); + + self.operations.push(EdgeOperation::SourceNode { + operand: operand.clone(), + }); + + operand + } + + pub fn target_node(&mut self) -> Wrapper { + let operand = Wrapper::::new(); + + self.operations.push(EdgeOperation::TargetNode { + operand: operand.clone(), + }); + + operand + } + + pub fn either_or(&mut self, 
either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(); + let mut or_operand = Wrapper::::new(); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(EdgeOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new() -> Self { + EdgeOperand::new().into() + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord) + } + + pub fn attribute(&self, attribute: A) -> Wrapper + where + A: Into, + { + self.0.write_or_panic().attribute(attribute.into()) + } + + pub fn attributes(&self) -> Wrapper { + self.0.write_or_panic().attributes() + } + + pub fn index(&self) -> Wrapper { + self.0.write_or_panic().index() + } + + pub fn in_group(&mut self, group: G) + where + G: Into>, + { + self.0.write_or_panic().in_group(group); + } + + pub fn has_attribute(&mut self, attribute: A) + where + A: Into>, + { + self.0.write_or_panic().has_attribute(attribute); + } + + pub fn source_node(&self) -> Wrapper { + self.0.write_or_panic().source_node() + } + + pub fn target_node(&self) -> Wrapper { + self.0.write_or_panic().target_node() + } + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +macro_rules! implement_value_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = Wrapper::::new(self.deep_clone(), SingleKind::$variant); + + self.operations + .push(EdgeIndicesOperation::EdgeIndexOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! 
implement_single_value_comparison_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations + .push($operation::EdgeIndexComparisonOperation { + operand: value.into(), + kind: SingleComparisonKind::$kind, + }); + } + }; +} + +macro_rules! implement_binary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations.push($operation::BinaryArithmeticOpration { + operand: value.into(), + kind: BinaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_assertion_operation { + ($name:ident, $operation:expr) => { + pub fn $name(&mut self) { + self.operations.push($operation); + } + }; +} + +macro_rules! implement_wrapper_operand { + ($name:ident) => { + pub fn $name(&self) { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_return { + ($name:ident, $return_operand:ident) => { + pub fn $name(&self) -> Wrapper<$return_operand> { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! 
implement_wrapper_operand_with_argument { + ($name:ident, $value_type:ty) => { + pub fn $name(&self, value: $value_type) { + self.0.write_or_panic().$name(value) + } + }; +} + +#[derive(Debug, Clone)] +pub enum EdgeIndexComparisonOperand { + Operand(EdgeIndexOperand), + Index(EdgeIndex), +} + +impl DeepClone for EdgeIndexComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Index(value) => Self::Index(*value), + } + } +} + +impl From> for EdgeIndexComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for EdgeIndexComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From for EdgeIndexComparisonOperand { + fn from(value: V) -> Self { + Self::Index(value.into()) + } +} + +#[derive(Debug, Clone)] +pub enum EdgeIndicesComparisonOperand { + Operand(EdgeIndicesOperand), + Indices(Vec), +} + +impl DeepClone for EdgeIndicesComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Indices(value) => Self::Indices(value.clone()), + } + } +} + +impl From> for EdgeIndicesComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for EdgeIndicesComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From> for EdgeIndicesComparisonOperand { + fn from(value: Vec) -> Self { + Self::Indices(value.into_iter().map(Into::into).collect()) + } +} + +impl + Clone, const N: usize> From<[V; N]> for EdgeIndicesComparisonOperand { + fn from(value: [V; N]) -> Self { + value.to_vec().into() + } +} + +#[derive(Debug, Clone)] +pub struct EdgeIndicesOperand { + pub(crate) context: EdgeOperand, + operations: Vec, +} + +impl 
DeepClone for EdgeIndicesOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl EdgeIndicesOperand { + pub(crate) fn new(context: EdgeOperand) -> Self { + Self { + context, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult + 'a> { + let values = Box::new(values) as BoxedIterator; + + self.operations + .iter() + .try_fold(values, |value_tuples, operation| { + operation.evaluate(medrecord, value_tuples) + }) + } + + implement_value_operation!(max, Max); + implement_value_operation!(min, Min); + implement_value_operation!(count, Count); + implement_value_operation!(sum, Sum); + implement_value_operation!(first, First); + implement_value_operation!(last, Last); + + implement_single_value_comparison_operation!(greater_than, EdgeIndicesOperation, GreaterThan); + implement_single_value_comparison_operation!( + greater_than_or_equal_to, + EdgeIndicesOperation, + GreaterThanOrEqualTo + ); + implement_single_value_comparison_operation!(less_than, EdgeIndicesOperation, LessThan); + implement_single_value_comparison_operation!( + less_than_or_equal_to, + EdgeIndicesOperation, + LessThanOrEqualTo + ); + implement_single_value_comparison_operation!(equal_to, EdgeIndicesOperation, EqualTo); + implement_single_value_comparison_operation!(not_equal_to, EdgeIndicesOperation, NotEqualTo); + implement_single_value_comparison_operation!(starts_with, EdgeIndicesOperation, StartsWith); + implement_single_value_comparison_operation!(ends_with, EdgeIndicesOperation, EndsWith); + implement_single_value_comparison_operation!(contains, EdgeIndicesOperation, Contains); + + pub fn is_in>(&mut self, values: V) { + self.operations + .push(EdgeIndicesOperation::EdgeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + 
pub fn is_not_in>(&mut self, values: V) { + self.operations + .push(EdgeIndicesOperation::EdgeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, EdgeIndicesOperation, Add); + implement_binary_arithmetic_operation!(sub, EdgeIndicesOperation, Sub); + implement_binary_arithmetic_operation!(mul, EdgeIndicesOperation, Mul); + implement_binary_arithmetic_operation!(pow, EdgeIndicesOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, EdgeIndicesOperation, Mod); + + implement_assertion_operation!(is_max, EdgeIndicesOperation::IsMax); + implement_assertion_operation!(is_min, EdgeIndicesOperation::IsMin); + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(self.context.clone()); + let mut or_operand = Wrapper::::new(self.context.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(EdgeIndicesOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: EdgeOperand) -> Self { + EdgeIndicesOperand::new(context).into() + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult + 'a> { + self.0.read_or_panic().evaluate(medrecord, values) + } + + implement_wrapper_operand_with_return!(max, EdgeIndexOperand); + implement_wrapper_operand_with_return!(min, EdgeIndexOperand); + implement_wrapper_operand_with_return!(count, EdgeIndexOperand); + implement_wrapper_operand_with_return!(sum, EdgeIndexOperand); + implement_wrapper_operand_with_return!(first, EdgeIndexOperand); + implement_wrapper_operand_with_return!(last, EdgeIndexOperand); + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + 
greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +#[derive(Debug, Clone)] +pub struct EdgeIndexOperand { + pub(crate) context: EdgeIndicesOperand, + pub(crate) kind: SingleKind, + operations: Vec, +} + +impl DeepClone for EdgeIndexOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.deep_clone(), + kind: self.kind.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl EdgeIndexOperand { + pub(crate) fn new(context: EdgeIndicesOperand, kind: SingleKind) -> Self { + Self { + context, + kind, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: EdgeIndex, + ) -> MedRecordResult> { + self.operations + .iter() + .try_fold(Some(value), |value, operation| { + if let Some(value) = value { + 
operation.evaluate(medrecord, value)
+                } else {
+                    Ok(None)
+                }
+            })
+    }
+
+    implement_single_value_comparison_operation!(greater_than, EdgeIndexOperation, GreaterThan);
+    implement_single_value_comparison_operation!(
+        greater_than_or_equal_to,
+        EdgeIndexOperation,
+        GreaterThanOrEqualTo
+    );
+    implement_single_value_comparison_operation!(less_than, EdgeIndexOperation, LessThan);
+    implement_single_value_comparison_operation!(
+        less_than_or_equal_to,
+        EdgeIndexOperation,
+        LessThanOrEqualTo
+    );
+    implement_single_value_comparison_operation!(equal_to, EdgeIndexOperation, EqualTo);
+    implement_single_value_comparison_operation!(not_equal_to, EdgeIndexOperation, NotEqualTo);
+    implement_single_value_comparison_operation!(starts_with, EdgeIndexOperation, StartsWith);
+    implement_single_value_comparison_operation!(ends_with, EdgeIndexOperation, EndsWith);
+    implement_single_value_comparison_operation!(contains, EdgeIndexOperation, Contains);
+
+    /// Appends an `EdgeIndicesComparisonOperation` of kind `IsIn` for `values`.
+    pub fn is_in<V: Into<EdgeIndicesComparisonOperand>>(&mut self, values: V) {
+        self.operations
+            .push(EdgeIndexOperation::EdgeIndicesComparisonOperation {
+                operand: values.into(),
+                kind: MultipleComparisonKind::IsIn,
+            });
+    }
+
+    /// Appends an `EdgeIndicesComparisonOperation` of kind `IsNotIn` for `values`.
+    pub fn is_not_in<V: Into<EdgeIndicesComparisonOperand>>(&mut self, values: V) {
+        self.operations
+            .push(EdgeIndexOperation::EdgeIndicesComparisonOperation {
+                operand: values.into(),
+                kind: MultipleComparisonKind::IsNotIn,
+            });
+    }
+
+    implement_binary_arithmetic_operation!(add, EdgeIndexOperation, Add);
+    implement_binary_arithmetic_operation!(sub, EdgeIndexOperation, Sub);
+    implement_binary_arithmetic_operation!(mul, EdgeIndexOperation, Mul);
+    implement_binary_arithmetic_operation!(pow, EdgeIndexOperation, Pow);
+    implement_binary_arithmetic_operation!(r#mod, EdgeIndexOperation, Mod);
+
+    /// Appends an `EitherOr` operation built from two sub-queries.
+    ///
+    /// Two fresh operands are created from clones of this operand's context and
+    /// kind, each closure populates one of them, and both are stored in a single
+    /// `EdgeIndexOperation::EitherOr`.
+    pub fn either_or<EQ, OQ>(&mut self, either_query: EQ, or_query: OQ)
+    where
+        EQ: FnOnce(&mut Wrapper<EdgeIndexOperand>),
+        OQ: FnOnce(&mut Wrapper<EdgeIndexOperand>),
+    {
+        let mut either_operand =
+            Wrapper::<EdgeIndexOperand>::new(self.context.clone(), self.kind.clone());
+        let mut or_operand =
+            Wrapper::<EdgeIndexOperand>::new(self.context.clone(), self.kind.clone());
+
+        either_query(&mut either_operand);
+        or_query(&mut or_operand);
+
+        self.operations.push(EdgeIndexOperation::EitherOr {
+            either: either_operand,
+            or: or_operand,
+        });
+    }
+
+    /// Misspelled legacy name of [`Self::either_or`], kept so existing callers
+    /// keep compiling. Prefer `either_or`.
+    pub fn eiter_or<EQ, OQ>(&mut self, either_query: EQ, or_query: OQ)
+    where
+        EQ: FnOnce(&mut Wrapper<EdgeIndexOperand>),
+        OQ: FnOnce(&mut Wrapper<EdgeIndexOperand>),
+    {
+        self.either_or(either_query, or_query);
+    }
+}
+
+impl Wrapper<EdgeIndexOperand> {
+    /// Wraps a new `EdgeIndexOperand` built from `context` and `kind`.
+    pub(crate) fn new(context: EdgeIndicesOperand, kind: SingleKind) -> Self {
+        EdgeIndexOperand::new(context, kind).into()
+    }
+
+    /// Takes the read lock and forwards to `EdgeIndexOperand::evaluate`.
+    pub(crate) fn evaluate(
+        &self,
+        medrecord: &MedRecord,
+        value: EdgeIndex,
+    ) -> MedRecordResult<Option<EdgeIndex>> {
+        self.0.read_or_panic().evaluate(medrecord, value)
+    }
+
+    implement_wrapper_operand_with_argument!(greater_than, impl Into<EdgeIndexComparisonOperand>);
+    implement_wrapper_operand_with_argument!(
+        greater_than_or_equal_to,
+        impl Into<EdgeIndexComparisonOperand>
+    );
+    implement_wrapper_operand_with_argument!(less_than, impl Into<EdgeIndexComparisonOperand>);
+    implement_wrapper_operand_with_argument!(
+        less_than_or_equal_to,
+        impl Into<EdgeIndexComparisonOperand>
+    );
+    implement_wrapper_operand_with_argument!(equal_to, impl Into<EdgeIndexComparisonOperand>);
+    implement_wrapper_operand_with_argument!(not_equal_to, impl Into<EdgeIndexComparisonOperand>);
+    implement_wrapper_operand_with_argument!(starts_with, impl Into<EdgeIndexComparisonOperand>);
+    implement_wrapper_operand_with_argument!(ends_with, impl Into<EdgeIndexComparisonOperand>);
+    implement_wrapper_operand_with_argument!(contains, impl Into<EdgeIndexComparisonOperand>);
+    implement_wrapper_operand_with_argument!(is_in, impl Into<EdgeIndicesComparisonOperand>);
+    implement_wrapper_operand_with_argument!(is_not_in, impl Into<EdgeIndicesComparisonOperand>);
+    implement_wrapper_operand_with_argument!(add, impl Into<EdgeIndexComparisonOperand>);
+    implement_wrapper_operand_with_argument!(sub, impl Into<EdgeIndexComparisonOperand>);
+    implement_wrapper_operand_with_argument!(mul, impl Into<EdgeIndexComparisonOperand>);
+    implement_wrapper_operand_with_argument!(pow, impl Into<EdgeIndexComparisonOperand>);
+    implement_wrapper_operand_with_argument!(r#mod, impl Into<EdgeIndexComparisonOperand>);
+
+    /// Takes the write lock and forwards to `EdgeIndexOperand::either_or`.
+    pub fn either_or<EQ, OQ>(&self, either_query: EQ, or_query: OQ)
+    where
+        EQ: FnOnce(&mut Wrapper<EdgeIndexOperand>),
+        OQ: FnOnce(&mut Wrapper<EdgeIndexOperand>),
+    {
+        self.0.write_or_panic().either_or(either_query, or_query);
+    }
+}
diff --git a/crates/medmodels-core/src/medrecord/querying/edges/operation.rs b/crates/medmodels-core/src/medrecord/querying/edges/operation.rs
new file mode 100644
index 00000000..0d36db8f
--- /dev/null
+++ 
b/crates/medmodels-core/src/medrecord/querying/edges/operation.rs @@ -0,0 +1,762 @@ +use super::{ + operand::{ + EdgeIndexComparisonOperand, EdgeIndexOperand, EdgeIndicesComparisonOperand, + EdgeIndicesOperand, + }, + BinaryArithmeticKind, EdgeOperand, MultipleComparisonKind, SingleComparisonKind, +}; +use crate::{ + errors::{MedRecordError, MedRecordResult}, + medrecord::{ + datatypes::{Contains, EndsWith, Mod, StartsWith}, + querying::{ + attributes::AttributesTreeOperand, + edges::SingleKind, + nodes::NodeOperand, + traits::{DeepClone, ReadWriteOrPanic}, + values::MultipleValuesOperand, + wrapper::Wrapper, + BoxedIterator, + }, + CardinalityWrapper, EdgeIndex, Group, MedRecordAttribute, MedRecordValue, + }, + MedRecord, +}; +use itertools::Itertools; +use std::{ + collections::HashSet, + ops::{Add, Mul, Sub}, +}; + +#[derive(Debug, Clone)] +pub enum EdgeOperation { + Values { + operand: Wrapper, + }, + Attributes { + operand: Wrapper, + }, + Indices { + operand: Wrapper, + }, + + InGroup { + group: CardinalityWrapper, + }, + HasAttribute { + attribute: CardinalityWrapper, + }, + + SourceNode { + operand: Wrapper, + }, + TargetNode { + operand: Wrapper, + }, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for EdgeOperation { + fn deep_clone(&self) -> Self { + match self { + Self::Values { operand } => Self::Values { + operand: operand.deep_clone(), + }, + Self::Attributes { operand } => Self::Attributes { + operand: operand.deep_clone(), + }, + Self::Indices { operand } => Self::Indices { + operand: operand.deep_clone(), + }, + Self::InGroup { group } => Self::InGroup { + group: group.clone(), + }, + Self::HasAttribute { attribute } => Self::HasAttribute { + attribute: attribute.clone(), + }, + Self::SourceNode { operand } => Self::SourceNode { + operand: operand.deep_clone(), + }, + Self::TargetNode { operand } => Self::TargetNode { + operand: operand.deep_clone(), + }, + Self::EitherOr { either, or } => Self::EitherOr { + either: 
either.deep_clone(),
+                or: or.deep_clone(),
+            },
+        }
+    }
+}
+
+impl EdgeOperation {
+    /// Applies this operation to `edge_indices` and returns the surviving edge
+    /// indices as a boxed iterator.
+    ///
+    /// Errors raised by the operand evaluated for a variant are propagated.
+    pub(crate) fn evaluate<'a>(
+        &self,
+        medrecord: &'a MedRecord,
+        edge_indices: impl Iterator<Item = &'a EdgeIndex> + 'a,
+    ) -> MedRecordResult<BoxedIterator<'a, &'a EdgeIndex>> {
+        Ok(match self {
+            Self::Values { operand } => Box::new(Self::evaluate_values(
+                medrecord,
+                edge_indices,
+                operand.clone(),
+            )?),
+            Self::Attributes { operand } => Box::new(Self::evaluate_attributes(
+                medrecord,
+                edge_indices,
+                operand.clone(),
+            )?),
+            Self::Indices { operand } => Box::new(Self::evaluate_indices(
+                medrecord,
+                edge_indices,
+                operand.clone(),
+            )?),
+            Self::InGroup { group } => Box::new(Self::evaluate_in_group(
+                medrecord,
+                edge_indices,
+                group.clone(),
+            )),
+            Self::HasAttribute { attribute } => Box::new(Self::evaluate_has_attribute(
+                medrecord,
+                edge_indices,
+                attribute.clone(),
+            )),
+            Self::SourceNode { operand } => Box::new(Self::evaluate_source_node(
+                medrecord,
+                edge_indices,
+                operand,
+            )?),
+            Self::TargetNode { operand } => Box::new(Self::evaluate_target_node(
+                medrecord,
+                edge_indices,
+                operand,
+            )?),
+            Self::EitherOr { either, or } => {
+                // The two sub-queries are evaluated against every edge of the
+                // MedRecord, so their union is intersected with the incoming
+                // `edge_indices`. Returning the union directly would silently
+                // discard every operation applied before `either_or`.
+                let matching =
+                    Self::evaluate_either_or(medrecord, either, or)?.collect::<HashSet<_>>();
+
+                Box::new(edge_indices.filter(move |edge_index| matching.contains(edge_index)))
+            }
+        })
+    }
+
+    #[inline]
+    pub(crate) fn get_values<'a>(
+        medrecord: &'a MedRecord,
+        edge_indices: impl Iterator<Item = &'a EdgeIndex>,
+        attribute: MedRecordAttribute,
+    ) -> impl Iterator<Item = (&'a EdgeIndex, MedRecordValue)> {
+        edge_indices.flat_map(move |edge_index| {
+            Some((
+                edge_index,
+                medrecord
+                    .edge_attributes(edge_index)
+                    .expect("Edge must exist")
+                    .get(&attribute)?
+ .clone(), + )) + }) + } + + #[inline] + fn evaluate_values<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator + 'a, + operand: Wrapper, + ) -> MedRecordResult> { + let values = Self::get_values( + medrecord, + edge_indices, + operand.0.read_or_panic().attribute.clone(), + ); + + Ok(operand.evaluate(medrecord, values)?.map(|value| value.0)) + } + + #[inline] + pub(crate) fn get_attributes<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator, + ) -> impl Iterator)> { + edge_indices.map(move |edge_index| { + let attributes = medrecord + .edge_attributes(edge_index) + .expect("Edge must exist") + .keys() + .cloned(); + + (edge_index, attributes.collect()) + }) + } + + #[inline] + fn evaluate_attributes<'a>( + medrecord: &'a MedRecord, + edge_indices: impl Iterator + 'a, + operand: Wrapper, + ) -> MedRecordResult> { + let attributes = Self::get_attributes(medrecord, edge_indices); + + Ok(operand + .evaluate(medrecord, attributes)? + .map(|value| value.0)) + } + + #[inline] + fn evaluate_indices<'a>( + medrecord: &MedRecord, + edge_indices: impl Iterator, + operand: Wrapper, + ) -> MedRecordResult> { + // TODO: This is a temporary solution. It should be optimized. + let edge_indices = edge_indices.collect::>(); + + let result = operand + .evaluate(medrecord, edge_indices.clone().into_iter().cloned())? 
+            .collect::<Vec<_>>();
+
+        Ok(edge_indices
+            .into_iter()
+            .filter(move |index| result.contains(index)))
+    }
+
+    /// Keeps the edges that belong to `group`; for `Multiple`, an edge must be
+    /// in every listed group.
+    #[inline]
+    fn evaluate_in_group<'a>(
+        medrecord: &'a MedRecord,
+        edge_indices: impl Iterator<Item = &'a EdgeIndex>,
+        group: CardinalityWrapper<Group>,
+    ) -> impl Iterator<Item = &'a EdgeIndex> {
+        edge_indices.filter(move |edge_index| {
+            let groups_of_edge = medrecord
+                .groups_of_edge(edge_index)
+                .expect("Edge must exist");
+
+            let groups_of_edge = groups_of_edge.collect::<Vec<_>>();
+
+            match &group {
+                CardinalityWrapper::Single(group) => groups_of_edge.contains(&group),
+                CardinalityWrapper::Multiple(groups) => {
+                    groups.iter().all(|group| groups_of_edge.contains(&group))
+                }
+            }
+        })
+    }
+
+    /// Keeps the edges that carry `attribute`; for `Multiple`, an edge must
+    /// carry every listed attribute.
+    #[inline]
+    fn evaluate_has_attribute<'a>(
+        medrecord: &'a MedRecord,
+        edge_indices: impl Iterator<Item = &'a EdgeIndex>,
+        attribute: CardinalityWrapper<MedRecordAttribute>,
+    ) -> impl Iterator<Item = &'a EdgeIndex> {
+        edge_indices.filter(move |edge_index| {
+            let attributes_of_edge = medrecord
+                .edge_attributes(edge_index)
+                .expect("Edge must exist")
+                .keys();
+
+            let attributes_of_edge = attributes_of_edge.collect::<Vec<_>>();
+
+            match &attribute {
+                CardinalityWrapper::Single(attribute) => attributes_of_edge.contains(&attribute),
+                CardinalityWrapper::Multiple(attributes) => attributes
+                    .iter()
+                    .all(|attribute| attributes_of_edge.contains(&attribute)),
+            }
+        })
+    }
+
+    /// Keeps the edges whose source node is selected by `operand`.
+    ///
+    /// Errors from evaluating the node operand are propagated.
+    #[inline]
+    fn evaluate_source_node<'a>(
+        medrecord: &'a MedRecord,
+        edge_indices: impl Iterator<Item = &'a EdgeIndex>,
+        operand: &Wrapper<NodeOperand>,
+    ) -> MedRecordResult<impl Iterator<Item = &'a EdgeIndex>> {
+        let node_indices = operand.evaluate(medrecord)?.collect::<HashSet<_>>();
+
+        Ok(edge_indices.filter(move |edge_index| {
+            let edge_endpoints = medrecord
+                .edge_endpoints(edge_index)
+                .expect("Edge must exist");
+
+            // Element 0 of the endpoints tuple is the source node; element 1 is
+            // the target, which `evaluate_target_node` checks.
+            node_indices.contains(edge_endpoints.0)
+        }))
+    }
+
+    #[inline]
+    fn evaluate_target_node<'a>(
+        medrecord: &'a MedRecord,
+        edge_indices: impl Iterator<Item = &'a EdgeIndex>,
+        operand: &Wrapper<NodeOperand>,
+    ) -> MedRecordResult<impl Iterator<Item = &'a EdgeIndex>> {
+        let node_indices = operand.evaluate(medrecord)?.collect::<HashSet<_>>();
+
+        Ok(edge_indices.filter(move |edge_index| {
+            let edge_endpoints = medrecord
.edge_endpoints(edge_index) + .expect("Edge must exist"); + + node_indices.contains(edge_endpoints.1) + })) + } + + #[inline] + fn evaluate_either_or<'a>( + medrecord: &'a MedRecord, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let either_result = either.evaluate(medrecord)?; + let or_result = or.evaluate(medrecord)?; + + Ok(either_result.chain(or_result).unique()) + } +} + +macro_rules! get_edge_index { + ($kind:ident, $indices:expr) => { + match $kind { + SingleKind::Max => EdgeIndicesOperation::get_max($indices)?.clone(), + SingleKind::Min => EdgeIndicesOperation::get_min($indices)?.clone(), + SingleKind::Count => EdgeIndicesOperation::get_count($indices), + SingleKind::Sum => EdgeIndicesOperation::get_sum($indices), + SingleKind::First => EdgeIndicesOperation::get_first($indices)?, + SingleKind::Last => EdgeIndicesOperation::get_last($indices)?, + } + }; +} + +macro_rules! get_edge_index_comparison_operand_index { + ($operand:ident, $medrecord:ident) => { + match $operand { + EdgeIndexComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + // TODO: This is a temporary solution. It should be optimized. 
+ let comparison_indices = context.evaluate($medrecord)?.cloned(); + + let comparison_index = get_edge_index!(kind, comparison_indices); + + comparison_index + } + EdgeIndexComparisonOperand::Index(index) => index.clone(), + } + }; +} + +#[derive(Debug, Clone)] +pub enum EdgeIndicesOperation { + EdgeIndexOperation { + operand: Wrapper, + }, + EdgeIndexComparisonOperation { + operand: EdgeIndexComparisonOperand, + kind: SingleComparisonKind, + }, + EdgeIndicesComparisonOperation { + operand: EdgeIndicesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: EdgeIndexComparisonOperand, + kind: BinaryArithmeticKind, + }, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for EdgeIndicesOperation { + fn deep_clone(&self) -> Self { + match self { + Self::EdgeIndexOperation { operand } => Self::EdgeIndexOperation { + operand: operand.deep_clone(), + }, + Self::EdgeIndexComparisonOperation { operand, kind } => { + Self::EdgeIndexComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::EdgeIndicesComparisonOperation { operand, kind } => { + Self::EdgeIndicesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl EdgeIndicesOperation { + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + indices: impl Iterator + 'a, + ) -> MedRecordResult> { + match self { + Self::EdgeIndexOperation { operand } => { + Self::evaluate_edge_index_operation(medrecord, indices, operand) + } + Self::EdgeIndexComparisonOperation { operand, kind } => { + 
Self::evaluate_edge_index_comparison_operation(medrecord, indices, operand, kind) + } + Self::EdgeIndicesComparisonOperation { operand, kind } => { + Self::evaluate_edge_indices_comparison_operation(medrecord, indices, operand, kind) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Ok(Box::new(Self::evaluate_binary_arithmetic_operation( + medrecord, + indices, + operand, + kind.clone(), + )?)) + } + Self::IsMax => { + let max_index = Self::get_max(indices)?; + + Ok(Box::new(std::iter::once(max_index))) + } + Self::IsMin => { + let min_index = Self::get_min(indices)?; + + Ok(Box::new(std::iter::once(min_index))) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, indices, either, or) + } + } + } + + #[inline] + pub(crate) fn get_max(indices: impl Iterator) -> MedRecordResult { + indices.max().ok_or(MedRecordError::QueryError( + "No indices to compare".to_string(), + )) + } + + #[inline] + pub(crate) fn get_min(indices: impl Iterator) -> MedRecordResult { + indices.min().ok_or(MedRecordError::QueryError( + "No indices to compare".to_string(), + )) + } + #[inline] + pub(crate) fn get_count(indices: impl Iterator) -> EdgeIndex { + indices.count() as EdgeIndex + } + + #[inline] + pub(crate) fn get_sum(indices: impl Iterator) -> EdgeIndex { + indices.sum() + } + + #[inline] + pub(crate) fn get_first( + mut indices: impl Iterator, + ) -> MedRecordResult { + indices.next().ok_or(MedRecordError::QueryError( + "No indices to get the first".to_string(), + )) + } + + #[inline] + pub(crate) fn get_last(indices: impl Iterator) -> MedRecordResult { + indices.last().ok_or(MedRecordError::QueryError( + "No indices to get the first".to_string(), + )) + } + + #[inline] + fn evaluate_edge_index_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let kind = &operand.0.read_or_panic().kind; + + let indices = indices.collect::>(); + + let index = get_edge_index!(kind, 
indices.clone().into_iter()); + + Ok(match operand.evaluate(medrecord, index)? { + Some(_) => Box::new(indices.into_iter()), + None => Box::new(std::iter::empty()), + }) + } + + #[inline] + fn evaluate_edge_index_comparison_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator + 'a, + comparison_operand: &EdgeIndexComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_index = + get_edge_index_comparison_operand_index!(comparison_operand, medrecord); + + match kind { + SingleComparisonKind::GreaterThan => Ok(Box::new( + indices.filter(move |index| index > &comparison_index), + )), + SingleComparisonKind::GreaterThanOrEqualTo => Ok(Box::new( + indices.filter(move |index| index >= &comparison_index), + )), + SingleComparisonKind::LessThan => Ok(Box::new( + indices.filter(move |index| index < &comparison_index), + )), + SingleComparisonKind::LessThanOrEqualTo => Ok(Box::new( + indices.filter(move |index| index <= &comparison_index), + )), + SingleComparisonKind::EqualTo => Ok(Box::new( + indices.filter(move |index| index == &comparison_index), + )), + SingleComparisonKind::NotEqualTo => Ok(Box::new( + indices.filter(move |index| index != &comparison_index), + )), + SingleComparisonKind::StartsWith => Ok(Box::new( + indices.filter(move |index| index.starts_with(&comparison_index)), + )), + SingleComparisonKind::EndsWith => Ok(Box::new( + indices.filter(move |index| index.ends_with(&comparison_index)), + )), + SingleComparisonKind::Contains => Ok(Box::new( + indices.filter(move |index| index.contains(&comparison_index)), + )), + } + } + + #[inline] + fn evaluate_edge_indices_comparison_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator + 'a, + comparison_operand: &EdgeIndicesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_indices = match comparison_operand { + EdgeIndicesComparisonOperand::Operand(operand) => { + let context = &operand.context; + + 
context.evaluate(medrecord)?.cloned().collect::>() + } + EdgeIndicesComparisonOperand::Indices(indices) => indices.clone(), + }; + + match kind { + MultipleComparisonKind::IsIn => Ok(Box::new( + indices.filter(move |index| comparison_indices.contains(index)), + )), + MultipleComparisonKind::IsNotIn => Ok(Box::new( + indices.filter(move |index| !comparison_indices.contains(index)), + )), + } + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + indices: impl Iterator, + operand: &EdgeIndexComparisonOperand, + kind: BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_index = get_edge_index_comparison_operand_index!(operand, medrecord); + + Ok(indices + .map(move |index| match kind { + BinaryArithmeticKind::Add => Ok(index.add(arithmetic_index)), + BinaryArithmeticKind::Sub => Ok(index.sub(arithmetic_index)), + BinaryArithmeticKind::Mul => Ok(index.mul(arithmetic_index)), + BinaryArithmeticKind::Pow => Ok(index.pow(arithmetic_index)), + BinaryArithmeticKind::Mod => index.r#mod(arithmetic_index), + }) + .collect::>>()? 
+ .into_iter()) + } + + #[inline] + fn evaluate_either_or<'a>( + medrecord: &'a MedRecord, + indices: impl Iterator, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let indices = indices.collect::>(); + + let either_indices = either.evaluate(medrecord, indices.clone().into_iter())?; + let or_indices = or.evaluate(medrecord, indices.into_iter())?; + + Ok(Box::new(either_indices.chain(or_indices).unique())) + } +} + +#[derive(Debug, Clone)] +pub enum EdgeIndexOperation { + EdgeIndexComparisonOperation { + operand: EdgeIndexComparisonOperand, + kind: SingleComparisonKind, + }, + EdgeIndicesComparisonOperation { + operand: EdgeIndicesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: EdgeIndexComparisonOperand, + kind: BinaryArithmeticKind, + }, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for EdgeIndexOperation { + fn deep_clone(&self) -> Self { + match self { + Self::EdgeIndexComparisonOperation { operand, kind } => { + Self::EdgeIndexComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::EdgeIndicesComparisonOperation { operand, kind } => { + Self::EdgeIndicesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl EdgeIndexOperation { + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + index: EdgeIndex, + ) -> MedRecordResult> { + match self { + Self::EdgeIndexComparisonOperation { operand, kind } => { + Self::evaluate_edge_index_comparison_operation(medrecord, index, operand, kind) + } + Self::EdgeIndicesComparisonOperation { operand, kind } => { + Self::evaluate_edge_indcies_comparison_operation(medrecord, index, operand, kind) 
+ } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, index, operand, kind) + } + Self::EitherOr { either, or } => Self::evaluate_either_or(medrecord, index, either, or), + } + } + + #[inline] + fn evaluate_edge_index_comparison_operation( + medrecord: &MedRecord, + index: EdgeIndex, + comparison_operand: &EdgeIndexComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_index = + get_edge_index_comparison_operand_index!(comparison_operand, medrecord); + + let comparison_result = match kind { + SingleComparisonKind::GreaterThan => index > comparison_index, + SingleComparisonKind::GreaterThanOrEqualTo => index >= comparison_index, + SingleComparisonKind::LessThan => index < comparison_index, + SingleComparisonKind::LessThanOrEqualTo => index <= comparison_index, + SingleComparisonKind::EqualTo => index == comparison_index, + SingleComparisonKind::NotEqualTo => index != comparison_index, + SingleComparisonKind::StartsWith => index.starts_with(&comparison_index), + SingleComparisonKind::EndsWith => index.ends_with(&comparison_index), + SingleComparisonKind::Contains => index.contains(&comparison_index), + }; + + Ok(if comparison_result { Some(index) } else { None }) + } + + #[inline] + fn evaluate_edge_indcies_comparison_operation( + medrecord: &MedRecord, + index: EdgeIndex, + comparison_operand: &EdgeIndicesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_indices = match comparison_operand { + EdgeIndicesComparisonOperand::Operand(operand) => { + let context = &operand.context; + + context.evaluate(medrecord)?.cloned().collect::>() + } + EdgeIndicesComparisonOperand::Indices(indices) => indices.clone(), + }; + + let comparison_result = match kind { + MultipleComparisonKind::IsIn => comparison_indices + .into_iter() + .any(|comparison_index| index == comparison_index), + MultipleComparisonKind::IsNotIn => comparison_indices + 
.into_iter() + .all(|comparison_index| index != comparison_index), + }; + + Ok(if comparison_result { Some(index) } else { None }) + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + index: EdgeIndex, + operand: &EdgeIndexComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_index = get_edge_index_comparison_operand_index!(operand, medrecord); + + Ok(Some(match kind { + BinaryArithmeticKind::Add => index.add(arithmetic_index), + BinaryArithmeticKind::Sub => index.sub(arithmetic_index), + BinaryArithmeticKind::Mul => index.mul(arithmetic_index), + BinaryArithmeticKind::Pow => index.pow(arithmetic_index), + BinaryArithmeticKind::Mod => index.r#mod(arithmetic_index)?, + })) + } + + #[inline] + fn evaluate_either_or( + medrecord: &MedRecord, + index: EdgeIndex, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let either_result = either.evaluate(medrecord, index)?; + let or_result = or.evaluate(medrecord, index)?; + + match (either_result, or_result) { + (Some(either_result), _) => Ok(Some(either_result)), + (None, Some(or_result)) => Ok(Some(or_result)), + _ => Ok(None), + } + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/edges/selection.rs b/crates/medmodels-core/src/medrecord/querying/edges/selection.rs new file mode 100644 index 00000000..a0d0a519 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/edges/selection.rs @@ -0,0 +1,32 @@ +use super::EdgeOperand; +use crate::{ + errors::MedRecordResult, + medrecord::{querying::wrapper::Wrapper, EdgeIndex, MedRecord}, +}; + +#[derive(Debug, Clone)] +pub struct EdgeSelection<'a> { + medrecord: &'a MedRecord, + operand: Wrapper, +} + +impl<'a> EdgeSelection<'a> { + pub fn new(medrecord: &'a MedRecord, query: Q) -> Self + where + Q: FnOnce(&mut Wrapper), + { + let mut operand = Wrapper::::new(); + + query(&mut operand); + + Self { medrecord, operand } + } + + pub fn iter(&'a self) -> MedRecordResult> { + 
self.operand.evaluate(self.medrecord) + } + + pub fn collect>(&'a self) -> MedRecordResult { + Ok(FromIterator::from_iter(self.iter()?)) + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/mod.rs b/crates/medmodels-core/src/medrecord/querying/mod.rs index 1f999f78..94728fe4 100644 --- a/crates/medmodels-core/src/medrecord/querying/mod.rs +++ b/crates/medmodels-core/src/medrecord/querying/mod.rs @@ -1,9 +1,8 @@ -mod operation; -mod selection; +pub mod attributes; +pub mod edges; +pub mod nodes; +mod traits; +pub mod values; +pub mod wrapper; -pub use self::operation::{ - edge, node, ArithmeticOperation, EdgeAttributeOperand, EdgeIndexOperand, EdgeOperand, - EdgeOperation, NodeAttributeOperand, NodeIndexOperand, NodeOperand, NodeOperation, - TransformationOperation, ValueOperand, -}; -pub(super) use self::selection::{EdgeSelection, NodeSelection}; +pub(crate) type BoxedIterator<'a, T> = Box + 'a>; diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/mod.rs b/crates/medmodels-core/src/medrecord/querying/nodes/mod.rs new file mode 100644 index 00000000..1041a7e9 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/nodes/mod.rs @@ -0,0 +1,68 @@ +mod operand; +mod operation; +mod selection; + +pub use operand::NodeOperand; +pub use operation::NodeOperation; +pub use selection::NodeSelection; +use std::fmt::Display; + +#[derive(Debug, Clone)] +pub enum SingleKind { + Max, + Min, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum SingleComparisonKind { + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, + EqualTo, + NotEqualTo, + StartsWith, + EndsWith, + Contains, +} + +#[derive(Debug, Clone)] +pub enum MultipleComparisonKind { + IsIn, + IsNotIn, +} + +#[derive(Debug, Clone)] +pub enum BinaryArithmeticKind { + Add, + Sub, + Mul, + Pow, + Mod, +} + +impl Display for BinaryArithmeticKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + 
BinaryArithmeticKind::Add => write!(f, "add"), + BinaryArithmeticKind::Sub => write!(f, "sub"), + BinaryArithmeticKind::Mul => write!(f, "mul"), + BinaryArithmeticKind::Pow => write!(f, "pow"), + BinaryArithmeticKind::Mod => write!(f, "mod"), + } + } +} + +#[derive(Debug, Clone)] +pub enum UnaryArithmeticKind { + Abs, + Trim, + TrimStart, + TrimEnd, + Lowercase, + Uppercase, +} diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs b/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs new file mode 100644 index 00000000..1800bc00 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/nodes/operand.rs @@ -0,0 +1,732 @@ +use super::{ + operation::{EdgeDirection, NodeIndexOperation, NodeIndicesOperation, NodeOperation}, + BinaryArithmeticKind, MultipleComparisonKind, SingleComparisonKind, SingleKind, + UnaryArithmeticKind, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{ + querying::{ + attributes::{self, AttributesTreeOperand}, + edges::EdgeOperand, + traits::{DeepClone, ReadWriteOrPanic}, + values::{self, MultipleValuesOperand}, + wrapper::{CardinalityWrapper, Wrapper}, + BoxedIterator, + }, + Group, MedRecordAttribute, NodeIndex, + }, + MedRecord, +}; +use std::fmt::Debug; + +#[derive(Debug, Clone)] +pub struct NodeOperand { + operations: Vec, +} + +impl DeepClone for NodeOperand { + fn deep_clone(&self) -> Self { + Self { + operations: self + .operations + .iter() + .map(|operation| operation.deep_clone()) + .collect(), + } + } +} + +impl NodeOperand { + pub(crate) fn new() -> Self { + Self { + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + ) -> MedRecordResult> { + let node_indices = Box::new(medrecord.node_indices()) as BoxedIterator<'a, &'a NodeIndex>; + + self.operations + .iter() + .try_fold(node_indices, |node_indices, operation| { + operation.evaluate(medrecord, node_indices) + }) + } + + pub fn attribute(&mut self, attribute: MedRecordAttribute) -> 
Wrapper { + let operand = Wrapper::::new( + values::Context::NodeOperand(self.deep_clone()), + attribute, + ); + + self.operations.push(NodeOperation::Values { + operand: operand.clone(), + }); + + operand + } + + pub fn attributes(&mut self) -> Wrapper { + let operand = Wrapper::::new(attributes::Context::NodeOperand( + self.deep_clone(), + )); + + self.operations.push(NodeOperation::Attributes { + operand: operand.clone(), + }); + + operand + } + + pub fn index(&mut self) -> Wrapper { + let operand = Wrapper::::new(self.deep_clone()); + + self.operations.push(NodeOperation::Indices { + operand: operand.clone(), + }); + + operand + } + + pub fn in_group(&mut self, group: G) + where + G: Into>, + { + self.operations.push(NodeOperation::InGroup { + group: group.into(), + }); + } + + pub fn has_attribute(&mut self, attribute: A) + where + A: Into>, + { + self.operations.push(NodeOperation::HasAttribute { + attribute: attribute.into(), + }); + } + + pub fn outgoing_edges(&mut self) -> Wrapper { + let operand = Wrapper::::new(); + + self.operations.push(NodeOperation::OutgoingEdges { + operand: operand.clone(), + }); + + operand + } + + pub fn incoming_edges(&mut self) -> Wrapper { + let operand = Wrapper::::new(); + + self.operations.push(NodeOperation::IncomingEdges { + operand: operand.clone(), + }); + + operand + } + + pub fn neighbors(&mut self, direction: EdgeDirection) -> Wrapper { + let operand = Wrapper::::new(); + + self.operations.push(NodeOperation::Neighbors { + operand: operand.clone(), + direction, + }); + + operand + } + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(); + let mut or_operand = Wrapper::::new(); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(NodeOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new() -> Self 
{ + NodeOperand::new().into() + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord) + } + + pub fn attribute(&mut self, attribute: MedRecordAttribute) -> Wrapper { + self.0.write_or_panic().attribute(attribute) + } + + pub fn attributes(&mut self) -> Wrapper { + self.0.write_or_panic().attributes() + } + + pub fn index(&mut self) -> Wrapper { + self.0.write_or_panic().index() + } + + pub fn in_group(&mut self, group: G) + where + G: Into>, + { + self.0.write_or_panic().in_group(group); + } + + pub fn has_attribute(&mut self, attribute: A) + where + A: Into>, + { + self.0.write_or_panic().has_attribute(attribute); + } + + pub fn outgoing_edges(&mut self) -> Wrapper { + self.0.write_or_panic().outgoing_edges() + } + + pub fn incoming_edges(&mut self) -> Wrapper { + self.0.write_or_panic().incoming_edges() + } + + pub fn neighbors(&mut self, direction: EdgeDirection) -> Wrapper { + self.0.write_or_panic().neighbors(direction) + } + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +macro_rules! implement_value_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = Wrapper::::new(self.deep_clone(), SingleKind::$variant); + + self.operations + .push(NodeIndicesOperation::NodeIndexOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! implement_single_value_comparison_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations + .push($operation::NodeIndexComparisonOperation { + operand: value.into(), + kind: SingleComparisonKind::$kind, + }); + } + }; +} + +macro_rules! 
implement_binary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations.push($operation::BinaryArithmeticOpration { + operand: value.into(), + kind: BinaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_unary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name(&mut self) { + self.operations.push($operation::UnaryArithmeticOperation { + kind: UnaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_assertion_operation { + ($name:ident, $operation:expr) => { + pub fn $name(&mut self) { + self.operations.push($operation); + } + }; +} + +macro_rules! implement_wrapper_operand { + ($name:ident) => { + pub fn $name(&self) { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_return { + ($name:ident, $return_operand:ident) => { + pub fn $name(&self) -> Wrapper<$return_operand> { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! 
implement_wrapper_operand_with_argument { + ($name:ident, $value_type:ty) => { + pub fn $name(&self, value: $value_type) { + self.0.write_or_panic().$name(value) + } + }; +} + +#[derive(Debug, Clone)] +pub enum NodeIndexComparisonOperand { + Operand(NodeIndexOperand), + Index(NodeIndex), +} + +impl DeepClone for NodeIndexComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Index(value) => Self::Index(value.clone()), + } + } +} + +impl From> for NodeIndexComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for NodeIndexComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From for NodeIndexComparisonOperand { + fn from(value: V) -> Self { + Self::Index(value.into()) + } +} + +#[derive(Debug, Clone)] +pub enum NodeIndicesComparisonOperand { + Operand(NodeIndicesOperand), + Indices(Vec), +} + +impl DeepClone for NodeIndicesComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Indices(value) => Self::Indices(value.clone()), + } + } +} + +impl From> for NodeIndicesComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for NodeIndicesComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From> for NodeIndicesComparisonOperand { + fn from(value: Vec) -> Self { + Self::Indices(value.into_iter().map(Into::into).collect()) + } +} + +impl + Clone, const N: usize> From<[V; N]> for NodeIndicesComparisonOperand { + fn from(value: [V; N]) -> Self { + value.to_vec().into() + } +} + +#[derive(Debug, Clone)] +pub struct NodeIndicesOperand { + pub(crate) context: NodeOperand, + operations: Vec, +} + 
+impl DeepClone for NodeIndicesOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl NodeIndicesOperand { + pub(crate) fn new(context: NodeOperand) -> Self { + Self { + context, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult + 'a> { + let values = Box::new(values) as BoxedIterator; + + self.operations + .iter() + .try_fold(values, |value_tuples, operation| { + operation.evaluate(medrecord, value_tuples) + }) + } + + implement_value_operation!(max, Max); + implement_value_operation!(min, Min); + implement_value_operation!(count, Count); + implement_value_operation!(sum, Sum); + implement_value_operation!(first, First); + implement_value_operation!(last, Last); + + implement_single_value_comparison_operation!(greater_than, NodeIndicesOperation, GreaterThan); + implement_single_value_comparison_operation!( + greater_than_or_equal_to, + NodeIndicesOperation, + GreaterThanOrEqualTo + ); + implement_single_value_comparison_operation!(less_than, NodeIndicesOperation, LessThan); + implement_single_value_comparison_operation!( + less_than_or_equal_to, + NodeIndicesOperation, + LessThanOrEqualTo + ); + implement_single_value_comparison_operation!(equal_to, NodeIndicesOperation, EqualTo); + implement_single_value_comparison_operation!(not_equal_to, NodeIndicesOperation, NotEqualTo); + implement_single_value_comparison_operation!(starts_with, NodeIndicesOperation, StartsWith); + implement_single_value_comparison_operation!(ends_with, NodeIndicesOperation, EndsWith); + implement_single_value_comparison_operation!(contains, NodeIndicesOperation, Contains); + + pub fn is_in>(&mut self, values: V) { + self.operations + .push(NodeIndicesOperation::NodeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsIn, + }); + } 
+ + pub fn is_not_in>(&mut self, values: V) { + self.operations + .push(NodeIndicesOperation::NodeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, NodeIndicesOperation, Add); + implement_binary_arithmetic_operation!(sub, NodeIndicesOperation, Sub); + implement_binary_arithmetic_operation!(mul, NodeIndicesOperation, Mul); + implement_binary_arithmetic_operation!(pow, NodeIndicesOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, NodeIndicesOperation, Mod); + + implement_unary_arithmetic_operation!(abs, NodeIndicesOperation, Abs); + implement_unary_arithmetic_operation!(trim, NodeIndicesOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, NodeIndicesOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, NodeIndicesOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, NodeIndicesOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, NodeIndicesOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + .push(NodeIndicesOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, NodeIndicesOperation::IsString); + implement_assertion_operation!(is_int, NodeIndicesOperation::IsInt); + implement_assertion_operation!(is_max, NodeIndicesOperation::IsMax); + implement_assertion_operation!(is_min, NodeIndicesOperation::IsMin); + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = Wrapper::::new(self.context.clone()); + let mut or_operand = Wrapper::::new(self.context.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(NodeIndicesOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: NodeOperand) -> Self { + 
NodeIndicesOperand::new(context).into() + } + + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult + 'a> { + self.0.read_or_panic().evaluate(medrecord, values) + } + + implement_wrapper_operand_with_return!(max, NodeIndexOperand); + implement_wrapper_operand_with_return!(min, NodeIndexOperand); + implement_wrapper_operand_with_return!(count, NodeIndexOperand); + implement_wrapper_operand_with_return!(sum, NodeIndexOperand); + implement_wrapper_operand_with_return!(first, NodeIndexOperand); + implement_wrapper_operand_with_return!(last, NodeIndexOperand); + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + 
self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +#[derive(Debug, Clone)] +pub struct NodeIndexOperand { + pub(crate) context: NodeIndicesOperand, + pub(crate) kind: SingleKind, + operations: Vec, +} + +impl DeepClone for NodeIndexOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.deep_clone(), + kind: self.kind.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl NodeIndexOperand { + pub(crate) fn new(context: NodeIndicesOperand, kind: SingleKind) -> Self { + Self { + context, + kind, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: NodeIndex, + ) -> MedRecordResult> { + self.operations + .iter() + .try_fold(Some(value), |value, operation| { + if let Some(value) = value { + operation.evaluate(medrecord, value) + } else { + Ok(None) + } + }) + } + + implement_single_value_comparison_operation!(greater_than, NodeIndexOperation, GreaterThan); + implement_single_value_comparison_operation!( + greater_than_or_equal_to, + NodeIndexOperation, + GreaterThanOrEqualTo + ); + implement_single_value_comparison_operation!(less_than, NodeIndexOperation, LessThan); + implement_single_value_comparison_operation!( + less_than_or_equal_to, + NodeIndexOperation, + LessThanOrEqualTo + ); + implement_single_value_comparison_operation!(equal_to, NodeIndexOperation, EqualTo); + implement_single_value_comparison_operation!(not_equal_to, NodeIndexOperation, NotEqualTo); + implement_single_value_comparison_operation!(starts_with, NodeIndexOperation, StartsWith); + 
implement_single_value_comparison_operation!(ends_with, NodeIndexOperation, EndsWith); + implement_single_value_comparison_operation!(contains, NodeIndexOperation, Contains); + + pub fn is_in>(&mut self, values: V) { + self.operations + .push(NodeIndexOperation::NodeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + pub fn is_not_in>(&mut self, values: V) { + self.operations + .push(NodeIndexOperation::NodeIndicesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, NodeIndexOperation, Add); + implement_binary_arithmetic_operation!(sub, NodeIndexOperation, Sub); + implement_binary_arithmetic_operation!(mul, NodeIndexOperation, Mul); + implement_binary_arithmetic_operation!(pow, NodeIndexOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, NodeIndexOperation, Mod); + + implement_unary_arithmetic_operation!(abs, NodeIndexOperation, Abs); + implement_unary_arithmetic_operation!(trim, NodeIndexOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, NodeIndexOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, NodeIndexOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, NodeIndexOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, NodeIndexOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations.push(NodeIndexOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, NodeIndexOperation::IsString); + implement_assertion_operation!(is_int, NodeIndexOperation::IsInt); + pub fn eiter_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + either_query(&mut 
either_operand); + or_query(&mut or_operand); + + self.operations.push(NodeIndexOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: NodeIndicesOperand, kind: SingleKind) -> Self { + NodeIndexOperand::new(context, kind).into() + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: NodeIndex, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, value) + } + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(abs); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + + pub fn either_or(&self, 
either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().eiter_or(either_query, or_query); + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs b/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs new file mode 100644 index 00000000..90e6692c --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/nodes/operation.rs @@ -0,0 +1,971 @@ +use super::{ + operand::{ + NodeIndexComparisonOperand, NodeIndexOperand, NodeIndicesComparisonOperand, + NodeIndicesOperand, + }, + BinaryArithmeticKind, MultipleComparisonKind, NodeOperand, SingleComparisonKind, SingleKind, + UnaryArithmeticKind, +}; +use crate::{ + errors::{MedRecordError, MedRecordResult}, + medrecord::{ + datatypes::{ + Abs, Contains, EndsWith, Lowercase, Mod, Pow, Slice, StartsWith, Trim, TrimEnd, + TrimStart, Uppercase, + }, + querying::{ + attributes::AttributesTreeOperand, + edges::EdgeOperand, + traits::{DeepClone, ReadWriteOrPanic}, + values::MultipleValuesOperand, + wrapper::{CardinalityWrapper, Wrapper}, + BoxedIterator, + }, + DataType, Group, MedRecord, MedRecordAttribute, MedRecordValue, NodeIndex, + }, +}; +use itertools::Itertools; +use roaring::RoaringBitmap; +use std::{ + cmp::Ordering, + collections::HashSet, + ops::{Add, Mul, Range, Sub}, +}; + +#[derive(Debug, Clone)] +pub enum EdgeDirection { + Incoming, + Outgoing, + Both, +} + +#[derive(Debug, Clone)] +pub enum NodeOperation { + Values { + operand: Wrapper, + }, + Attributes { + operand: Wrapper, + }, + Indices { + operand: Wrapper, + }, + + InGroup { + group: CardinalityWrapper, + }, + HasAttribute { + attribute: CardinalityWrapper, + }, + + OutgoingEdges { + operand: Wrapper, + }, + IncomingEdges { + operand: Wrapper, + }, + + Neighbors { + operand: Wrapper, + direction: EdgeDirection, + }, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for NodeOperation { + fn deep_clone(&self) -> Self { + 
match self { + Self::Values { operand } => Self::Values { + operand: operand.deep_clone(), + }, + Self::Attributes { operand } => Self::Attributes { + operand: operand.deep_clone(), + }, + Self::Indices { operand } => Self::Indices { + operand: operand.deep_clone(), + }, + Self::InGroup { group } => Self::InGroup { + group: group.clone(), + }, + Self::HasAttribute { attribute } => Self::HasAttribute { + attribute: attribute.clone(), + }, + Self::OutgoingEdges { operand } => Self::OutgoingEdges { + operand: operand.deep_clone(), + }, + Self::IncomingEdges { operand } => Self::IncomingEdges { + operand: operand.deep_clone(), + }, + Self::Neighbors { + operand, + direction: drection, + } => Self::Neighbors { + operand: operand.deep_clone(), + direction: drection.clone(), + }, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl NodeOperation { + pub(crate) fn evaluate<'a>( + &self, + medrecord: &'a MedRecord, + node_indices: impl Iterator + 'a, + ) -> MedRecordResult> { + Ok(match self { + Self::Values { operand } => Box::new(Self::evaluate_values( + medrecord, + node_indices, + operand.clone(), + )?), + Self::Attributes { operand } => Box::new(Self::evaluate_attributes( + medrecord, + node_indices, + operand.clone(), + )?), + Self::Indices { operand } => Box::new(Self::evaluate_indices( + medrecord, + node_indices, + operand.clone(), + )?), + Self::InGroup { group } => Box::new(Self::evaluate_in_group( + medrecord, + node_indices, + group.clone(), + )), + Self::HasAttribute { attribute } => Box::new(Self::evaluate_has_attribute( + medrecord, + node_indices, + attribute.clone(), + )), + Self::OutgoingEdges { operand } => Box::new(Self::evaluate_outgoing_edges( + medrecord, + node_indices, + operand.clone(), + )?), + Self::IncomingEdges { operand } => Box::new(Self::evaluate_incoming_edges( + medrecord, + node_indices, + operand.clone(), + )?), + Self::Neighbors { + operand, + direction: 
drection, + } => Box::new(Self::evaluate_neighbors( + medrecord, + node_indices, + operand.clone(), + drection.clone(), + )?), + Self::EitherOr { either, or } => { + // TODO: This is a temporary solution. It should be optimized. + let either_result = either.evaluate(medrecord)?.collect::>(); + let or_result = or.evaluate(medrecord)?.collect::>(); + + Box::new(either_result.into_iter().chain(or_result).unique()) + } + }) + } + + #[inline] + pub(crate) fn get_values<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + attribute: MedRecordAttribute, + ) -> impl Iterator { + node_indices.flat_map(move |node_index| { + Some(( + node_index, + medrecord + .node_attributes(node_index) + .expect("Edge must exist") + .get(&attribute)? + .clone(), + )) + }) + } + + #[inline] + fn evaluate_values<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator + 'a, + operand: Wrapper, + ) -> MedRecordResult> { + let values = Self::get_values( + medrecord, + node_indices, + operand.0.read_or_panic().attribute.clone(), + ); + + Ok(operand.evaluate(medrecord, values)?.map(|value| value.0)) + } + + #[inline] + pub(crate) fn get_attributes<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + ) -> impl Iterator)> { + node_indices.map(move |node_index| { + let attributes = medrecord + .node_attributes(node_index) + .expect("Edge must exist") + .keys() + .cloned(); + + (node_index, attributes.collect()) + }) + } + + #[inline] + fn evaluate_attributes<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator + 'a, + operand: Wrapper, + ) -> MedRecordResult> { + let attributes = Self::get_attributes(medrecord, node_indices); + + Ok(operand + .evaluate(medrecord, attributes)? + .map(|value| value.0)) + } + + #[inline] + fn evaluate_indices<'a>( + medrecord: &MedRecord, + edge_indices: impl Iterator, + operand: Wrapper, + ) -> MedRecordResult> { + // TODO: This is a temporary solution. It should be optimized. 
+ let node_indices = edge_indices.collect::>(); + + let result = operand + .evaluate(medrecord, node_indices.clone().into_iter().cloned())? + .collect::>(); + + Ok(node_indices + .into_iter() + .filter(move |index| result.contains(index))) + } + + #[inline] + fn evaluate_in_group<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + group: CardinalityWrapper, + ) -> impl Iterator { + node_indices.filter(move |node_index| { + let groups_of_node = medrecord + .groups_of_node(node_index) + .expect("Node must exist"); + + let groups_of_node = groups_of_node.collect::>(); + + match &group { + CardinalityWrapper::Single(group) => groups_of_node.contains(&group), + CardinalityWrapper::Multiple(groups) => { + groups.iter().all(|group| groups_of_node.contains(&group)) + } + } + }) + } + + #[inline] + fn evaluate_has_attribute<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + attribute: CardinalityWrapper, + ) -> impl Iterator { + node_indices.filter(move |node_index| { + let attributes_of_node = medrecord + .node_attributes(node_index) + .expect("Node must exist") + .keys(); + + let attributes_of_node = attributes_of_node.collect::>(); + + match &attribute { + CardinalityWrapper::Single(attribute) => attributes_of_node.contains(&attribute), + CardinalityWrapper::Multiple(attributes) => attributes + .iter() + .all(|attribute| attributes_of_node.contains(&attribute)), + } + }) + } + + #[inline] + fn evaluate_outgoing_edges<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + operand: Wrapper, + ) -> MedRecordResult> { + let edge_indices = operand.evaluate(medrecord)?.collect::(); + + Ok(node_indices.filter(move |node_index| { + let outgoing_edge_indices = medrecord + .outgoing_edges(node_index) + .expect("Node must exist"); + + let outgoing_edge_indices = outgoing_edge_indices.collect::(); + + !outgoing_edge_indices.is_disjoint(&edge_indices) + })) + } + + #[inline] + fn evaluate_incoming_edges<'a>( + medrecord: &'a MedRecord, + 
node_indices: impl Iterator, + operand: Wrapper, + ) -> MedRecordResult> { + let edge_indices = operand.evaluate(medrecord)?.collect::(); + + Ok(node_indices.filter(move |node_index| { + let incoming_edge_indices = medrecord + .incoming_edges(node_index) + .expect("Node must exist"); + + let incoming_edge_indices = incoming_edge_indices.collect::(); + + !incoming_edge_indices.is_disjoint(&edge_indices) + })) + } + + #[inline] + fn evaluate_neighbors<'a>( + medrecord: &'a MedRecord, + node_indices: impl Iterator, + operand: Wrapper, + direction: EdgeDirection, + ) -> MedRecordResult> { + let result = operand.evaluate(medrecord)?.collect::>(); + + Ok(node_indices.filter(move |node_index| { + let mut neighbors: Box> = match direction { + EdgeDirection::Incoming => Box::new( + medrecord + .neighbors_incoming(node_index) + .expect("Node must exist"), + ), + EdgeDirection::Outgoing => Box::new( + medrecord + .neighbors_outgoing(node_index) + .expect("Node must exist"), + ), + EdgeDirection::Both => Box::new( + medrecord + .neighbors_undirected(node_index) + .expect("Node must exist"), + ), + }; + + neighbors.any(|neighbor| result.contains(&neighbor)) + })) + } +} + +macro_rules! get_node_index { + ($kind:ident, $indices:expr) => { + match $kind { + SingleKind::Max => NodeIndicesOperation::get_max($indices)?.clone(), + SingleKind::Min => NodeIndicesOperation::get_min($indices)?.clone(), + SingleKind::Count => NodeIndicesOperation::get_count($indices), + SingleKind::Sum => NodeIndicesOperation::get_sum($indices)?, + SingleKind::First => NodeIndicesOperation::get_first($indices)?, + SingleKind::Last => NodeIndicesOperation::get_last($indices)?, + } + }; +} + +macro_rules! get_node_index_comparison_operand { + ($operand:ident, $medrecord:ident) => { + match $operand { + NodeIndexComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let kind = &operand.kind; + + // TODO: This is a temporary solution. It should be optimized. 
+ let comparison_indices = context.evaluate($medrecord)?.cloned(); + + let comparison_index = get_node_index!(kind, comparison_indices); + + comparison_index + } + NodeIndexComparisonOperand::Index(index) => index.clone(), + } + }; +} + +#[derive(Debug, Clone)] +pub enum NodeIndicesOperation { + NodeIndexOperation { + operand: Wrapper, + }, + NodeIndexComparisonOperation { + operand: NodeIndexComparisonOperand, + kind: SingleComparisonKind, + }, + NodeIndicesComparisonOperation { + operand: NodeIndicesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: NodeIndexComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for NodeIndicesOperation { + fn deep_clone(&self) -> Self { + match self { + Self::NodeIndexOperation { operand } => Self::NodeIndexOperation { + operand: operand.deep_clone(), + }, + Self::NodeIndexComparisonOperation { operand, kind } => { + Self::NodeIndexComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::NodeIndicesComparisonOperation { operand, kind } => { + Self::NodeIndicesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl NodeIndicesOperation { + pub(crate) fn evaluate<'a>( + &self, 
+ medrecord: &'a MedRecord, + indices: impl Iterator + 'a, + ) -> MedRecordResult> { + match self { + Self::NodeIndexOperation { operand } => { + Self::evaluate_node_index_operation(medrecord, indices, operand) + } + Self::NodeIndexComparisonOperation { operand, kind } => { + Self::evaluate_node_index_comparison_operation(medrecord, indices, operand, kind) + } + Self::NodeIndicesComparisonOperation { operand, kind } => { + Self::evaluate_node_indices_comparison_operation(medrecord, indices, operand, kind) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Ok(Box::new(Self::evaluate_binary_arithmetic_operation( + medrecord, + indices, + operand, + kind.clone(), + )?)) + } + Self::UnaryArithmeticOperation { kind } => Ok(Box::new( + Self::evaluate_unary_arithmetic_operation(indices, kind.clone()), + )), + Self::Slice(range) => Ok(Box::new(Self::evaluate_slice(indices, range.clone()))), + Self::IsString => { + Ok(Box::new(indices.filter(|index| { + matches!(index, MedRecordAttribute::String(_)) + }))) + } + Self::IsInt => { + Ok(Box::new(indices.filter(|index| { + matches!(index, MedRecordAttribute::Int(_)) + }))) + } + Self::IsMax => { + let max_index = Self::get_max(indices)?; + + Ok(Box::new(std::iter::once(max_index))) + } + Self::IsMin => { + let min_index = Self::get_min(indices)?; + + Ok(Box::new(std::iter::once(min_index))) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, indices, either, or) + } + } + } + + #[inline] + pub(crate) fn get_max( + mut indices: impl Iterator, + ) -> MedRecordResult { + let max_index = indices.next().ok_or(MedRecordError::QueryError( + "No indices to compare".to_string(), + ))?; + + indices.try_fold(max_index, |max_index, index| { + match index + .partial_cmp(&max_index) { + Some(Ordering::Greater) => Ok(index), + None => { + let first_dtype = DataType::from(index); + let second_dtype = DataType::from(max_index); + + Err(MedRecordError::QueryError(format!( + "Cannot compare indices of data 
types {} and {}. Consider narrowing down the indices using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(max_index), + } + }) + } + + #[inline] + pub(crate) fn get_min( + mut indices: impl Iterator, + ) -> MedRecordResult { + let min_index = indices.next().ok_or(MedRecordError::QueryError( + "No indices to compare".to_string(), + ))?; + + indices.try_fold(min_index, |min_index, index| { + match index.partial_cmp(&min_index) { + Some(Ordering::Less) => Ok(index), + None => { + let first_dtype = DataType::from(index); + let second_dtype = DataType::from(min_index); + + Err(MedRecordError::QueryError(format!( + "Cannot compare indices of data types {} and {}. Consider narrowing down the indices using .is_string() or .is_int()", + first_dtype, second_dtype + ))) + } + _ => Ok(min_index), + } + }) + } + #[inline] + pub(crate) fn get_count(indices: impl Iterator) -> NodeIndex { + MedRecordAttribute::Int(indices.count() as i64) + } + + #[inline] + // 🥊💥 + pub(crate) fn get_sum( + mut indices: impl Iterator, + ) -> MedRecordResult { + let first_value = indices + .next() + .ok_or(MedRecordError::QueryError("No indices to sum".to_string()))?; + + indices.try_fold(first_value, |sum, index| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&index); + + sum.add(index).map_err(|_| { + MedRecordError::QueryError(format!( + "Cannot add indices of data types {} and {}. 
Consider narrowing down the indices using .is_string() or .is_int()", + first_dtype, second_dtype + )) + }) + }) + } + + #[inline] + pub(crate) fn get_first( + mut indices: impl Iterator, + ) -> MedRecordResult { + indices.next().ok_or(MedRecordError::QueryError( + "No indices to get the first".to_string(), + )) + } + + #[inline] + pub(crate) fn get_last(indices: impl Iterator) -> MedRecordResult { + indices.last().ok_or(MedRecordError::QueryError( + "No indices to get the first".to_string(), + )) + } + + #[inline] + fn evaluate_node_index_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let kind = &operand.0.read_or_panic().kind; + + let indices = indices.collect::>(); + + let index = get_node_index!(kind, indices.clone().into_iter()); + + Ok(match operand.evaluate(medrecord, index)? { + Some(_) => Box::new(indices.into_iter()), + None => Box::new(std::iter::empty()), + }) + } + + #[inline] + fn evaluate_node_index_comparison_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator + 'a, + comparison_operand: &NodeIndexComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_index = get_node_index_comparison_operand!(comparison_operand, medrecord); + + match kind { + SingleComparisonKind::GreaterThan => Ok(Box::new( + indices.filter(move |index| index > &comparison_index), + )), + SingleComparisonKind::GreaterThanOrEqualTo => Ok(Box::new( + indices.filter(move |index| index >= &comparison_index), + )), + SingleComparisonKind::LessThan => Ok(Box::new( + indices.filter(move |index| index < &comparison_index), + )), + SingleComparisonKind::LessThanOrEqualTo => Ok(Box::new( + indices.filter(move |index| index <= &comparison_index), + )), + SingleComparisonKind::EqualTo => Ok(Box::new( + indices.filter(move |index| index == &comparison_index), + )), + SingleComparisonKind::NotEqualTo => Ok(Box::new( + indices.filter(move |index| index != &comparison_index), + 
)), + SingleComparisonKind::StartsWith => Ok(Box::new( + indices.filter(move |index| index.starts_with(&comparison_index)), + )), + SingleComparisonKind::EndsWith => Ok(Box::new( + indices.filter(move |index| index.ends_with(&comparison_index)), + )), + SingleComparisonKind::Contains => Ok(Box::new( + indices.filter(move |index| index.contains(&comparison_index)), + )), + } + } + + #[inline] + fn evaluate_node_indices_comparison_operation<'a>( + medrecord: &MedRecord, + indices: impl Iterator + 'a, + comparison_operand: &NodeIndicesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_indices = match comparison_operand { + NodeIndicesComparisonOperand::Operand(operand) => { + let context = &operand.context; + + context.evaluate(medrecord)?.cloned().collect::>() + } + NodeIndicesComparisonOperand::Indices(indices) => indices.clone(), + }; + + match kind { + MultipleComparisonKind::IsIn => Ok(Box::new( + indices.filter(move |index| comparison_indices.contains(index)), + )), + MultipleComparisonKind::IsNotIn => Ok(Box::new( + indices.filter(move |index| !comparison_indices.contains(index)), + )), + } + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + indices: impl Iterator, + operand: &NodeIndexComparisonOperand, + kind: BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_index = get_node_index_comparison_operand!(operand, medrecord); + + let indices = indices + .map(move |index| { + match kind { + BinaryArithmeticKind::Add => index.add(arithmetic_index.clone()), + BinaryArithmeticKind::Sub => index.sub(arithmetic_index.clone()), + BinaryArithmeticKind::Mul => { + index.clone().mul(arithmetic_index.clone()) + } + BinaryArithmeticKind::Pow => { + index.clone().pow(arithmetic_index.clone()) + } + BinaryArithmeticKind::Mod => { + index.clone().r#mod(arithmetic_index.clone()) + } + } + .map_err(|_| { + MedRecordError::QueryError(format!( + "Failed arithmetic operation {}. 
Consider narrowing down the indices using .is_string() or .is_int()", + kind, + )) + }) + }); + + // TODO: This is a temporary solution. It should be optimized. + Ok(indices.collect::>>()?.into_iter()) + } + + #[inline] + fn evaluate_unary_arithmetic_operation( + indices: impl Iterator, + kind: UnaryArithmeticKind, + ) -> impl Iterator { + indices.map(move |index| match kind { + UnaryArithmeticKind::Abs => index.abs(), + UnaryArithmeticKind::Trim => index.trim(), + UnaryArithmeticKind::TrimStart => index.trim_start(), + UnaryArithmeticKind::TrimEnd => index.trim_end(), + UnaryArithmeticKind::Lowercase => index.lowercase(), + UnaryArithmeticKind::Uppercase => index.uppercase(), + }) + } + + #[inline] + fn evaluate_slice( + indices: impl Iterator, + range: Range, + ) -> impl Iterator { + indices.map(move |index| index.slice(range.clone())) + } + + #[inline] + fn evaluate_either_or<'a>( + medrecord: &'a MedRecord, + indices: impl Iterator, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let indices = indices.collect::>(); + + let either_indices = either.evaluate(medrecord, indices.clone().into_iter())?; + let or_indices = or.evaluate(medrecord, indices.into_iter())?; + + Ok(Box::new(either_indices.chain(or_indices).unique())) + } +} + +#[derive(Debug, Clone)] +pub enum NodeIndexOperation { + NodeIndexComparisonOperation { + operand: NodeIndexComparisonOperand, + kind: SingleComparisonKind, + }, + NodeIndicesComparisonOperation { + operand: NodeIndicesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: NodeIndexComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for NodeIndexOperation { + fn deep_clone(&self) -> Self { + match self { + Self::NodeIndexComparisonOperation { operand, kind } => { + Self::NodeIndexComparisonOperation { + 
operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::NodeIndicesComparisonOperation { operand, kind } => { + Self::NodeIndicesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl NodeIndexOperation { + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + index: NodeIndex, + ) -> MedRecordResult> { + match self { + Self::NodeIndexComparisonOperation { operand, kind } => { + Self::evaluate_node_index_comparison_operation(medrecord, index, operand, kind) + } + Self::NodeIndicesComparisonOperation { operand, kind } => { + Self::evaluate_node_indices_comparison_operation(medrecord, index, operand, kind) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, index, operand, kind) + } + Self::UnaryArithmeticOperation { kind } => Ok(Some(match kind { + UnaryArithmeticKind::Abs => index.abs(), + UnaryArithmeticKind::Trim => index.trim(), + UnaryArithmeticKind::TrimStart => index.trim_start(), + UnaryArithmeticKind::TrimEnd => index.trim_end(), + UnaryArithmeticKind::Lowercase => index.lowercase(), + UnaryArithmeticKind::Uppercase => index.uppercase(), + })), + Self::Slice(range) => Ok(Some(index.slice(range.clone()))), + Self::IsString => Ok(match index { + MedRecordAttribute::String(_) => Some(index), + _ => None, + }), + Self::IsInt => Ok(match index { + MedRecordAttribute::Int(_) => Some(index), + _ => None, + }), + Self::EitherOr { either, or } => 
Self::evaluate_either_or(medrecord, index, either, or), + } + } + + #[inline] + fn evaluate_node_index_comparison_operation( + medrecord: &MedRecord, + index: NodeIndex, + comparison_operand: &NodeIndexComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_index = get_node_index_comparison_operand!(comparison_operand, medrecord); + + let comparison_result = match kind { + SingleComparisonKind::GreaterThan => index > comparison_index, + SingleComparisonKind::GreaterThanOrEqualTo => index >= comparison_index, + SingleComparisonKind::LessThan => index < comparison_index, + SingleComparisonKind::LessThanOrEqualTo => index <= comparison_index, + SingleComparisonKind::EqualTo => index == comparison_index, + SingleComparisonKind::NotEqualTo => index != comparison_index, + SingleComparisonKind::StartsWith => index.starts_with(&comparison_index), + SingleComparisonKind::EndsWith => index.ends_with(&comparison_index), + SingleComparisonKind::Contains => index.contains(&comparison_index), + }; + + Ok(if comparison_result { Some(index) } else { None }) + } + + #[inline] + fn evaluate_node_indices_comparison_operation( + medrecord: &MedRecord, + index: NodeIndex, + comparison_operand: &NodeIndicesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_indices = match comparison_operand { + NodeIndicesComparisonOperand::Operand(operand) => { + let context = &operand.context; + + context.evaluate(medrecord)?.cloned().collect::>() + } + NodeIndicesComparisonOperand::Indices(indices) => indices.clone(), + }; + + let comparison_result = match kind { + MultipleComparisonKind::IsIn => comparison_indices + .into_iter() + .any(|comparison_index| index == comparison_index), + MultipleComparisonKind::IsNotIn => comparison_indices + .into_iter() + .all(|comparison_index| index != comparison_index), + }; + + Ok(if comparison_result { Some(index) } else { None }) + } + + #[inline] + fn 
evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + index: NodeIndex, + operand: &NodeIndexComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_index = get_node_index_comparison_operand!(operand, medrecord); + + Ok(Some(match kind { + BinaryArithmeticKind::Add => index.add(arithmetic_index)?, + BinaryArithmeticKind::Sub => index.sub(arithmetic_index)?, + BinaryArithmeticKind::Mul => index.mul(arithmetic_index)?, + BinaryArithmeticKind::Pow => index.pow(arithmetic_index)?, + BinaryArithmeticKind::Mod => index.r#mod(arithmetic_index)?, + })) + } + + #[inline] + fn evaluate_either_or( + medrecord: &MedRecord, + index: NodeIndex, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let either_result = either.evaluate(medrecord, index.clone())?; + let or_result = or.evaluate(medrecord, index)?; + + match (either_result, or_result) { + (Some(either_result), _) => Ok(Some(either_result)), + (None, Some(or_result)) => Ok(Some(or_result)), + _ => Ok(None), + } + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/nodes/selection.rs b/crates/medmodels-core/src/medrecord/querying/nodes/selection.rs new file mode 100644 index 00000000..d994543d --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/nodes/selection.rs @@ -0,0 +1,35 @@ +use super::NodeOperand; +use crate::{ + errors::MedRecordResult, + medrecord::{querying::wrapper::Wrapper, MedRecord, NodeIndex}, +}; + +#[derive(Debug, Clone)] +pub struct NodeSelection<'a> { + medrecord: &'a MedRecord, + operand: Wrapper, +} + +impl<'a> NodeSelection<'a> { + pub fn new(medrecord: &'a MedRecord, query: Q) -> Self + where + Q: FnOnce(&mut Wrapper), + { + let mut operand = Wrapper::::new(); + + query(&mut operand); + + Self { medrecord, operand } + } + + pub fn iter(self) -> MedRecordResult> { + self.operand.evaluate(self.medrecord) + } + + pub fn collect(self) -> MedRecordResult + where + B: FromIterator<&'a NodeIndex>, + { + 
Ok(FromIterator::from_iter(self.iter()?)) + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/operation/edge_operation.rs b/crates/medmodels-core/src/medrecord/querying/operation/edge_operation.rs deleted file mode 100644 index f005c53f..00000000 --- a/crates/medmodels-core/src/medrecord/querying/operation/edge_operation.rs +++ /dev/null @@ -1,475 +0,0 @@ -#![allow(clippy::should_implement_trait)] - -use super::{ - operand::{ArithmeticOperation, EdgeIndexInOperand, IntoVecEdgeIndex, ValueOperand}, - AttributeOperation, NodeOperation, Operation, -}; -use crate::medrecord::{ - datatypes::{ - Abs, Ceil, Floor, Lowercase, Mod, Pow, Round, Slice, Sqrt, Trim, TrimEnd, TrimStart, - Uppercase, - }, - EdgeIndex, MedRecord, MedRecordAttribute, -}; - -#[derive(Debug, Clone)] -pub enum EdgeIndexOperation { - Gt(EdgeIndex), - Lt(EdgeIndex), - Gte(EdgeIndex), - Lte(EdgeIndex), - Eq(EdgeIndex), - In(Box), -} - -#[derive(Debug, Clone)] -pub enum EdgeOperation { - Attribute(AttributeOperation), - Index(EdgeIndexOperation), - - ConnectedSource(MedRecordAttribute), - ConnectedTarget(MedRecordAttribute), - InGroup(MedRecordAttribute), - HasAttribute(MedRecordAttribute), - - ConnectedSourceWith(Box), - ConnectedTargetWith(Box), - - HasParallelEdgesWith(Box), - HasParallelEdgesWithSelfComparison(Box), - - And(Box<(EdgeOperation, EdgeOperation)>), - Or(Box<(EdgeOperation, EdgeOperation)>), - Not(Box), -} - -impl Operation for EdgeOperation { - type IndexType = EdgeIndex; - - fn evaluate<'a>( - self, - medrecord: &'a MedRecord, - indices: impl Iterator + 'a, - ) -> Box + 'a> { - match self { - EdgeOperation::Attribute(attribute_operation) => { - Self::evaluate_attribute(indices, attribute_operation, |index| { - medrecord.edge_attributes(index) - }) - } - EdgeOperation::Index(index_operation) => { - Self::evaluate_index(medrecord, indices, index_operation) - } - - EdgeOperation::ConnectedSource(attribute_operand) => Box::new( - Self::evaluate_connected_target(medrecord, 
indices, attribute_operand), - ), - EdgeOperation::ConnectedTarget(attribute_operand) => Box::new( - Self::evaluate_connected_source(medrecord, indices, attribute_operand), - ), - EdgeOperation::InGroup(attribute_operand) => Box::new(Self::evaluate_in_group( - medrecord, - indices, - attribute_operand, - )), - EdgeOperation::HasAttribute(attribute_operand) => Box::new( - Self::evaluate_has_attribute(indices, attribute_operand, |index| { - medrecord.edge_attributes(index) - }), - ), - - EdgeOperation::ConnectedSourceWith(operation) => Box::new( - Self::evaluate_connected_source_with(medrecord, indices, *operation), - ), - EdgeOperation::ConnectedTargetWith(operation) => Box::new( - Self::evaluate_connected_target_with(medrecord, indices, *operation), - ), - - EdgeOperation::HasParallelEdgesWith(operation) => { - Self::evaluate_has_parallel_edges_with(medrecord, Box::new(indices), *operation) - } - EdgeOperation::HasParallelEdgesWithSelfComparison(operation) => { - Self::evaluate_has_parallel_edges_with_compare_to_self( - medrecord, - Box::new(indices), - *operation, - ) - } - - EdgeOperation::And(operations) => Box::new(Self::evaluate_and( - medrecord, - indices.collect::>(), - (*operations).0, - (*operations).1, - )), - EdgeOperation::Or(operations) => Box::new(Self::evaluate_or( - medrecord, - indices.collect::>(), - (*operations).0, - (*operations).1, - )), - EdgeOperation::Not(operation) => Box::new(Self::evaluate_not( - medrecord, - indices.collect::>(), - *operation, - )), - } - } -} - -impl EdgeOperation { - pub fn and(self, operation: EdgeOperation) -> EdgeOperation { - EdgeOperation::And(Box::new((self, operation))) - } - - pub fn or(self, operation: EdgeOperation) -> EdgeOperation { - EdgeOperation::Or(Box::new((self, operation))) - } - - pub fn xor(self, operation: EdgeOperation) -> EdgeOperation { - EdgeOperation::And(Box::new((self, operation))).not() - } - - pub fn not(self) -> EdgeOperation { - EdgeOperation::Not(Box::new(self)) - } - - fn 
evaluate_index<'a>( - medrecord: &'a MedRecord, - edge_indices: impl Iterator + 'a, - operation: EdgeIndexOperation, - ) -> Box + 'a> { - match operation { - EdgeIndexOperation::Gt(operand) => { - Box::new(Self::evaluate_index_gt(edge_indices, operand)) - } - EdgeIndexOperation::Lt(operand) => { - Box::new(Self::evaluate_index_lt(edge_indices, operand)) - } - EdgeIndexOperation::Gte(operand) => { - Box::new(Self::evaluate_index_gte(edge_indices, operand)) - } - EdgeIndexOperation::Lte(operand) => { - Box::new(Self::evaluate_index_lte(edge_indices, operand)) - } - EdgeIndexOperation::Eq(operand) => { - Box::new(Self::evaluate_index_eq(edge_indices, operand)) - } - EdgeIndexOperation::In(operands) => Box::new(Self::evaluate_index_in( - edge_indices, - operands.into_vec_edge_index(medrecord), - )), - } - } - - fn evaluate_connected_target<'a>( - medrecord: &'a MedRecord, - edge_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - ) -> impl Iterator { - edge_indices.filter(move |index| { - let Ok(endpoints) = medrecord.edge_endpoints(index) else { - return false; - }; - - *endpoints.1 == attribute_operand - }) - } - - fn evaluate_connected_source<'a>( - medrecord: &'a MedRecord, - edge_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - ) -> impl Iterator { - edge_indices.filter(move |index| { - let Ok(endpoints) = medrecord.edge_endpoints(index) else { - return false; - }; - - *endpoints.0 == attribute_operand - }) - } - - fn evaluate_in_group<'a>( - medrecord: &'a MedRecord, - edge_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - ) -> impl Iterator { - let edges_in_group = match medrecord.edges_in_group(&attribute_operand) { - Ok(edges_in_group) => edges_in_group.collect::>(), - Err(_) => Vec::new(), - }; - - edge_indices.filter(move |index| edges_in_group.contains(index)) - } - - fn evaluate_connected_target_with<'a>( - medrecord: &'a MedRecord, - edge_indices: impl Iterator, - operation: NodeOperation, - ) -> impl 
Iterator { - edge_indices.filter(move |index| { - let Ok(endpoints) = medrecord.edge_endpoints(index) else { - return false; - }; - - operation - .clone() - .evaluate(medrecord, vec![endpoints.1].into_iter()) - .count() - > 0 - }) - } - - fn evaluate_connected_source_with<'a>( - medrecord: &'a MedRecord, - edge_indices: impl Iterator, - operation: NodeOperation, - ) -> impl Iterator { - edge_indices.filter(move |index| { - let Ok(endpoints) = medrecord.edge_endpoints(index) else { - return false; - }; - - operation - .clone() - .evaluate(medrecord, vec![endpoints.0].into_iter()) - .count() - > 0 - }) - } - - fn evaluate_has_parallel_edges_with<'a>( - medrecord: &'a MedRecord, - edge_indices: Box + 'a>, - operation: EdgeOperation, - ) -> Box + 'a> { - Box::new(edge_indices.filter(move |index| { - let Ok(endpoints) = medrecord.edge_endpoints(index) else { - return false; - }; - - let edges = medrecord - .edges_connecting(vec![endpoints.0], vec![endpoints.1]) - .filter(|other_index| other_index != index); - - operation.clone().evaluate(medrecord, edges).count() > 0 - })) - } - - fn convert_value_operand<'a>( - medrecord: &'a MedRecord, - index: &'a EdgeIndex, - value_operand: ValueOperand, - ) -> Option { - match value_operand { - ValueOperand::Value(value) => Some(ValueOperand::Value(value)), - ValueOperand::Evaluate(attribute) => Some(ValueOperand::Value( - medrecord - .edge_attributes(index) - .ok()? - .get(&attribute)? 
- .clone(), - )), - ValueOperand::ArithmeticOperation(operation, attribute, other_value) => { - let value = medrecord.edge_attributes(index).ok()?.get(&attribute)?; - - let result = match operation { - ArithmeticOperation::Addition => value.clone() + other_value, - ArithmeticOperation::Subtraction => value.clone() - other_value, - ArithmeticOperation::Multiplication => value.clone() * other_value, - ArithmeticOperation::Division => value.clone() / other_value, - ArithmeticOperation::Power => value.clone().pow(other_value), - ArithmeticOperation::Modulo => value.clone().r#mod(other_value), - } - .ok()?; - - Some(ValueOperand::Value(result)) - } - ValueOperand::Slice(attribute, range) => { - let value = medrecord.edge_attributes(index).ok()?.get(&attribute)?; - - Some(ValueOperand::Value(value.clone().slice(range))) - } - ValueOperand::TransformationOperation(operation, attribute) => { - let value = medrecord.edge_attributes(index).ok()?.get(&attribute)?; - - let result = match operation { - super::operand::TransformationOperation::Round => value.clone().round(), - super::operand::TransformationOperation::Ceil => value.clone().ceil(), - super::operand::TransformationOperation::Floor => value.clone().floor(), - super::operand::TransformationOperation::Abs => value.clone().abs(), - super::operand::TransformationOperation::Sqrt => value.clone().sqrt(), - super::operand::TransformationOperation::Trim => value.clone().trim(), - super::operand::TransformationOperation::TrimStart => { - value.clone().trim_start() - } - super::operand::TransformationOperation::TrimEnd => value.clone().trim_end(), - super::operand::TransformationOperation::Lowercase => value.clone().lowercase(), - super::operand::TransformationOperation::Uppercase => value.clone().uppercase(), - }; - - Some(ValueOperand::Value(result)) - } - } - } - fn evaluate_has_parallel_edges_with_compare_to_self<'a>( - medrecord: &'a MedRecord, - edge_indices: Box + 'a>, - operation: EdgeOperation, - ) -> Box + 'a> { - 
Box::new(edge_indices.filter(move |index| { - let Ok(endpoints) = medrecord.edge_endpoints(index) else { - return false; - }; - - let edges = medrecord - .edges_connecting(vec![endpoints.0], vec![endpoints.1]) - .filter(|other_index| other_index != index); - - let operation = operation.clone(); - - let EdgeOperation::Attribute(operation) = operation else { - return operation.evaluate(medrecord, edges).count() > 0; - }; - - match operation { - AttributeOperation::Gt(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Gt(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::Lt(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Lt(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::Gte(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Gte(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::Lte(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Lte(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::Eq(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Eq(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::Neq(attribute, value) 
=> { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Neq(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::In(attribute, value) => { - Self::evaluate_attribute( - edges, - AttributeOperation::In(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::NotIn(attribute, value) => { - Self::evaluate_attribute( - edges, - AttributeOperation::In(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::StartsWith(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::StartsWith(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::EndsWith(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::EndsWith(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - AttributeOperation::Contains(attribute, value) => { - let Some(value) = Self::convert_value_operand(medrecord, index, value) else { - return false; - }; - - Self::evaluate_attribute( - edges, - AttributeOperation::Contains(attribute, value), - |index| medrecord.edge_attributes(index), - ) - .count() - > 0 - } - } - })) - } -} diff --git a/crates/medmodels-core/src/medrecord/querying/operation/mod.rs b/crates/medmodels-core/src/medrecord/querying/operation/mod.rs deleted file mode 100644 index 174adeda..00000000 --- a/crates/medmodels-core/src/medrecord/querying/operation/mod.rs +++ /dev/null @@ -1,394 +0,0 @@ -mod edge_operation; -mod node_operation; -mod operand; - -pub use self::{ - 
edge_operation::EdgeOperation, - node_operation::NodeOperation, - operand::{ - edge, node, ArithmeticOperation, EdgeAttributeOperand, EdgeIndexOperand, EdgeOperand, - NodeAttributeOperand, NodeIndexOperand, NodeOperand, TransformationOperation, ValueOperand, - }, -}; -use crate::{ - errors::MedRecordError, - medrecord::{ - datatypes::{ - Abs, Ceil, Contains, EndsWith, Floor, Lowercase, Mod, PartialNeq, Pow, Round, Slice, - Sqrt, StartsWith, Trim, TrimEnd, TrimStart, Uppercase, - }, - Attributes, MedRecord, MedRecordAttribute, MedRecordValue, - }, -}; - -macro_rules! implement_attribute_evaluate { - ($name: ident, $evaluate: ident) => { - fn $name<'a, P>( - node_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - value_operand: ValueOperand, - attributes_for_index_fn: P, - ) -> impl Iterator - where - P: Fn(&Self::IndexType) -> Result<&'a Attributes, MedRecordError>, - Self::IndexType: 'a, - { - node_indices.filter(move |index| { - let Ok(attributes) = attributes_for_index_fn(index) else { - return false; - }; - - let Some(value) = attributes.get(&attribute_operand) else { - return false; - }; - - match &value_operand { - ValueOperand::Value(value_operand) => value.$evaluate(value_operand), - ValueOperand::Evaluate(value_attribute) => { - let Some(other) = attributes.get(&value_attribute) else { - return false; - }; - - value.$evaluate(other) - } - ValueOperand::ArithmeticOperation( - operation, - value_attribute, - value_operand, - ) => { - let Some(other) = attributes.get(&value_attribute) else { - return false; - }; - - let operation = match operation { - ArithmeticOperation::Addition => other.clone() + value_operand.clone(), - ArithmeticOperation::Subtraction => { - other.clone() - value_operand.clone() - } - ArithmeticOperation::Multiplication => { - other.clone() * value_operand.clone() - } - ArithmeticOperation::Division => other.clone() / value_operand.clone(), - ArithmeticOperation::Power => other.clone().pow(value_operand.clone()), - 
ArithmeticOperation::Modulo => { - other.clone().r#mod(value_operand.clone()) - } - }; - - match operation { - Ok(operation) => value.$evaluate(&operation), - Err(_) => false, - } - } - ValueOperand::TransformationOperation(operation, value_attribute) => { - let Some(other) = attributes.get(&value_attribute) else { - return false; - }; - - let operation = match operation { - TransformationOperation::Round => other.clone().round(), - TransformationOperation::Ceil => other.clone().ceil(), - TransformationOperation::Floor => other.clone().floor(), - TransformationOperation::Abs => other.clone().abs(), - TransformationOperation::Sqrt => other.clone().sqrt(), - TransformationOperation::Trim => other.clone().trim(), - TransformationOperation::TrimStart => other.clone().trim_start(), - TransformationOperation::TrimEnd => other.clone().trim_end(), - TransformationOperation::Lowercase => other.clone().lowercase(), - TransformationOperation::Uppercase => other.clone().uppercase(), - }; - - value.$evaluate(&operation) - } - ValueOperand::Slice(value_attribute, range) => { - let Some(other) = attributes.get(&value_attribute) else { - return false; - }; - - value.$evaluate(&other.clone().slice(range.clone())) - } - } - }) - } - }; -} - -macro_rules! 
implement_index_evaluate { - ($name: ident, $evaluate: ident) => { - fn $name<'a>( - indices: impl Iterator, - operand: Self::IndexType, - ) -> impl Iterator - where - Self::IndexType: 'a, - { - indices.filter(move |index| (*index).$evaluate(&operand)) - } - }; -} - -pub(super) trait Operation: Sized { - type IndexType: PartialEq + PartialNeq + PartialOrd; - - fn evaluate_and<'a>( - medrecord: &'a MedRecord, - indices: Vec<&'a Self::IndexType>, - operation1: Self, - operation2: Self, - ) -> impl Iterator { - let operation1_indices = operation1 - .evaluate(medrecord, indices.clone().into_iter()) - .collect::>(); - let operation2_indices = operation2 - .evaluate(medrecord, indices.clone().into_iter()) - .collect::>(); - - indices.into_iter().filter(move |index| { - operation1_indices.contains(index) && operation2_indices.contains(index) - }) - } - - fn evaluate_or<'a>( - medrecord: &'a MedRecord, - indices: Vec<&'a Self::IndexType>, - operation1: Self, - operation2: Self, - ) -> impl Iterator { - let operation1_indices = operation1 - .evaluate(medrecord, indices.clone().into_iter()) - .collect::>(); - let operation2_indices = operation2 - .evaluate(medrecord, indices.clone().into_iter()) - .collect::>(); - - indices.into_iter().filter(move |index| { - operation1_indices.contains(index) || operation2_indices.contains(index) - }) - } - - fn evaluate_not<'a>( - medrecord: &'a MedRecord, - indices: Vec<&'a Self::IndexType>, - operation: Self, - ) -> impl Iterator { - let operation_indices = operation - .evaluate(medrecord, indices.clone().into_iter()) - .collect::>(); - - indices - .into_iter() - .filter(move |index| !operation_indices.contains(index)) - } - - fn evaluate_attribute_in<'a, P>( - node_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - value_operands: Vec, - attributes_for_index_fn: P, - ) -> impl Iterator - where - P: Fn(&Self::IndexType) -> Result<&'a Attributes, MedRecordError>, - Self::IndexType: 'a, - { - node_indices.filter(move 
|index| { - let Ok(attributes) = attributes_for_index_fn(index) else { - return false; - }; - - let Some(value) = attributes.get(&attribute_operand) else { - return false; - }; - - value_operands.contains(value) - }) - } - - fn evaluate_attribute_not_in<'a, P>( - node_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - value_operands: Vec, - attributes_for_index_fn: P, - ) -> impl Iterator - where - P: Fn(&Self::IndexType) -> Result<&'a Attributes, MedRecordError>, - Self::IndexType: 'a, - { - node_indices.filter(move |index| { - let Ok(attributes) = attributes_for_index_fn(index) else { - return false; - }; - - let Some(value) = attributes.get(&attribute_operand) else { - return false; - }; - - !value_operands.contains(value) - }) - } - - implement_attribute_evaluate!(evaluate_attribute_gt, gt); - implement_attribute_evaluate!(evaluate_attribute_lt, lt); - implement_attribute_evaluate!(evaluate_attribute_gte, ge); - implement_attribute_evaluate!(evaluate_attribute_lte, le); - implement_attribute_evaluate!(evaluate_attribute_eq, eq); - implement_attribute_evaluate!(evaluate_attribute_neq, neq); - implement_attribute_evaluate!(evaluate_attribute_starts_with, starts_with); - implement_attribute_evaluate!(evaluate_attribute_ends_with, ends_with); - implement_attribute_evaluate!(evaluate_attribute_contains, contains); - - fn evaluate_has_attribute<'a, P>( - indices: impl Iterator, - attribute_operand: MedRecordAttribute, - attributes_for_index_fn: P, - ) -> impl Iterator - where - P: Fn(&Self::IndexType) -> Result<&'a Attributes, MedRecordError>, - Self::IndexType: 'a, - { - indices.filter(move |index| { - let Ok(attributes) = attributes_for_index_fn(index) else { - return false; - }; - - attributes.contains_key(&attribute_operand) - }) - } - - fn evaluate_attribute<'a, P>( - indices: impl Iterator + 'a, - operation: AttributeOperation, - attributes_for_index_fn: P, - ) -> Box + 'a> - where - P: Fn(&Self::IndexType) -> Result<&'a Attributes, 
MedRecordError> + 'a, - Self: 'a, - { - match operation { - AttributeOperation::Gt(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_gt( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::Lt(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_lt( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::Gte(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_gte( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::Lte(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_lte( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::Eq(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_eq( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::Neq(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_neq( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::In(attribute_operand, value_operands) => { - Box::new(Self::evaluate_attribute_in( - indices, - attribute_operand, - value_operands, - attributes_for_index_fn, - )) - } - AttributeOperation::NotIn(attribute_operand, value_operands) => { - Box::new(Self::evaluate_attribute_not_in( - indices, - attribute_operand, - value_operands, - attributes_for_index_fn, - )) - } - AttributeOperation::StartsWith(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_starts_with( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - AttributeOperation::EndsWith(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_ends_with( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - 
AttributeOperation::Contains(attribute_operand, value_operand) => { - Box::new(Self::evaluate_attribute_contains( - indices, - attribute_operand, - value_operand, - attributes_for_index_fn, - )) - } - } - } - - implement_index_evaluate!(evaluate_index_gt, gt); - implement_index_evaluate!(evaluate_index_lt, lt); - implement_index_evaluate!(evaluate_index_gte, ge); - implement_index_evaluate!(evaluate_index_lte, le); - implement_index_evaluate!(evaluate_index_eq, eq); - - fn evaluate_index_in<'a>( - indices: impl Iterator, - operands: Vec, - ) -> impl Iterator - where - Self::IndexType: 'a, - { - indices.filter(move |index| operands.contains(index)) - } - - fn evaluate<'a>( - self, - medrecord: &'a MedRecord, - indices: impl Iterator + 'a, - ) -> Box + 'a>; -} - -#[derive(Debug, Clone)] -pub enum AttributeOperation { - Gt(MedRecordAttribute, ValueOperand), - Lt(MedRecordAttribute, ValueOperand), - Gte(MedRecordAttribute, ValueOperand), - Lte(MedRecordAttribute, ValueOperand), - Eq(MedRecordAttribute, ValueOperand), - Neq(MedRecordAttribute, ValueOperand), - In(MedRecordAttribute, Vec), - NotIn(MedRecordAttribute, Vec), - StartsWith(MedRecordAttribute, ValueOperand), - EndsWith(MedRecordAttribute, ValueOperand), - Contains(MedRecordAttribute, ValueOperand), -} diff --git a/crates/medmodels-core/src/medrecord/querying/operation/node_operation.rs b/crates/medmodels-core/src/medrecord/querying/operation/node_operation.rs deleted file mode 100644 index 677db205..00000000 --- a/crates/medmodels-core/src/medrecord/querying/operation/node_operation.rs +++ /dev/null @@ -1,246 +0,0 @@ -#![allow(clippy::should_implement_trait)] - -use super::{ - edge_operation::EdgeOperation, - operand::{IntoVecNodeIndex, NodeIndexInOperand}, - AttributeOperation, Operation, -}; -use crate::medrecord::{ - datatypes::{Contains, EndsWith, StartsWith}, - MedRecord, MedRecordAttribute, NodeIndex, -}; - -macro_rules! 
implement_index_evaluate { - ($name: ident, $evaluate: ident) => { - fn $name<'a>( - indices: impl Iterator, - operand: NodeIndex, - ) -> impl Iterator { - indices.filter(move |index| (*index).$evaluate(&operand)) - } - }; -} - -#[derive(Debug, Clone)] -pub enum NodeIndexOperation { - Gt(NodeIndex), - Lt(NodeIndex), - Gte(NodeIndex), - Lte(NodeIndex), - Eq(NodeIndex), - In(Box), - StartsWith(NodeIndex), - EndsWith(NodeIndex), - Contains(NodeIndex), -} - -#[derive(Debug, Clone)] -pub enum NodeOperation { - Attribute(AttributeOperation), - Index(NodeIndexOperation), - - InGroup(MedRecordAttribute), - HasAttribute(MedRecordAttribute), - - HasIncomingEdgeWith(Box), - HasOutgoingEdgeWith(Box), - HasNeighborWith(Box), - HasNeighborUndirectedWith(Box), - - And(Box<(NodeOperation, NodeOperation)>), - Or(Box<(NodeOperation, NodeOperation)>), - Not(Box), -} - -impl Operation for NodeOperation { - type IndexType = NodeIndex; - - fn evaluate<'a>( - self, - medrecord: &'a MedRecord, - indices: impl Iterator + 'a, - ) -> Box + 'a> { - match self { - NodeOperation::Attribute(attribute_operation) => { - Self::evaluate_attribute(indices, attribute_operation, |index| { - medrecord.node_attributes(index) - }) - } - NodeOperation::Index(index_operation) => { - Self::evaluate_index(medrecord, indices, index_operation) - } - - NodeOperation::InGroup(attribute_operand) => Box::new(Self::evaluate_in_group( - medrecord, - indices, - attribute_operand, - )), - NodeOperation::HasAttribute(attribute_operand) => Box::new( - Self::evaluate_has_attribute(indices, attribute_operand, |index| { - medrecord.node_attributes(index) - }), - ), - - NodeOperation::HasOutgoingEdgeWith(operation) => Box::new( - Self::evaluate_has_outgoing_edge_with(medrecord, indices, *operation), - ), - NodeOperation::HasIncomingEdgeWith(operation) => Box::new( - Self::evaluate_has_incoming_edge_with(medrecord, indices, *operation), - ), - NodeOperation::HasNeighborWith(operation) => Box::new( - 
Self::evaluate_has_neighbor_with(medrecord, indices, *operation), - ), - NodeOperation::HasNeighborUndirectedWith(operation) => Box::new( - Self::evaluate_has_neighbor_undirected_with(medrecord, indices, *operation), - ), - - NodeOperation::And(operations) => Box::new(Self::evaluate_and( - medrecord, - indices.collect::>(), - (*operations).0, - (*operations).1, - )), - NodeOperation::Or(operations) => Box::new(Self::evaluate_or( - medrecord, - indices.collect::>(), - (*operations).0, - (*operations).1, - )), - NodeOperation::Not(operation) => Box::new(Self::evaluate_not( - medrecord, - indices.collect::>(), - *operation, - )), - } - } -} - -impl NodeOperation { - pub fn and(self, operation: NodeOperation) -> NodeOperation { - NodeOperation::And(Box::new((self, operation))) - } - - pub fn or(self, operation: NodeOperation) -> NodeOperation { - NodeOperation::Or(Box::new((self, operation))) - } - - pub fn xor(self, operation: NodeOperation) -> NodeOperation { - NodeOperation::And(Box::new((self, operation))).not() - } - - pub fn not(self) -> NodeOperation { - NodeOperation::Not(Box::new(self)) - } - - fn evaluate_index<'a>( - medrecord: &'a MedRecord, - node_indices: impl Iterator + 'a, - operation: NodeIndexOperation, - ) -> Box + 'a> { - match operation { - NodeIndexOperation::Gt(operand) => { - Box::new(Self::evaluate_index_gt(node_indices, operand)) - } - NodeIndexOperation::Lt(operand) => { - Box::new(Self::evaluate_index_lt(node_indices, operand)) - } - NodeIndexOperation::Gte(operand) => { - Box::new(Self::evaluate_index_gte(node_indices, operand)) - } - NodeIndexOperation::Lte(operand) => { - Box::new(Self::evaluate_index_lte(node_indices, operand)) - } - NodeIndexOperation::Eq(operand) => { - Box::new(Self::evaluate_index_eq(node_indices, operand)) - } - NodeIndexOperation::In(operands) => Box::new(Self::evaluate_index_in( - node_indices, - operands.into_vec_node_index(medrecord), - )), - NodeIndexOperation::StartsWith(operand) => { - 
Box::new(Self::evaluate_index_starts_with(node_indices, operand)) - } - NodeIndexOperation::EndsWith(operand) => { - Box::new(Self::evaluate_index_ends_with(node_indices, operand)) - } - NodeIndexOperation::Contains(operand) => { - Box::new(Self::evaluate_index_contains(node_indices, operand)) - } - } - } - - implement_index_evaluate!(evaluate_index_starts_with, starts_with); - implement_index_evaluate!(evaluate_index_ends_with, ends_with); - implement_index_evaluate!(evaluate_index_contains, contains); - - fn evaluate_in_group<'a>( - medrecord: &'a MedRecord, - node_indices: impl Iterator, - attribute_operand: MedRecordAttribute, - ) -> impl Iterator { - let nodes_in_group = match medrecord.nodes_in_group(&attribute_operand) { - Ok(nodes_in_group) => nodes_in_group.collect::>(), - Err(_) => Vec::new(), - }; - - node_indices.filter(move |index| nodes_in_group.contains(index)) - } - - fn evaluate_has_outgoing_edge_with<'a>( - medrecord: &'a MedRecord, - node_indices: impl Iterator, - operation: EdgeOperation, - ) -> impl Iterator { - node_indices.filter(move |index| { - let Ok(edges) = medrecord.outgoing_edges(index) else { - return false; - }; - - let edge_indices = operation.clone().evaluate(medrecord, edges); - - edge_indices.count() > 0 - }) - } - - fn evaluate_has_incoming_edge_with<'a>( - medrecord: &'a MedRecord, - node_indices: impl Iterator, - operation: EdgeOperation, - ) -> impl Iterator { - node_indices.filter(move |index| { - let Ok(edges) = medrecord.incoming_edges(index) else { - return false; - }; - - operation.clone().evaluate(medrecord, edges).count() > 0 - }) - } - - fn evaluate_has_neighbor_with<'a>( - medrecord: &'a MedRecord, - node_indices: impl Iterator, - operation: NodeOperation, - ) -> impl Iterator { - node_indices.filter(move |index| { - let Ok(neighbors) = medrecord.neighbors(index) else { - return false; - }; - - operation.clone().evaluate(medrecord, neighbors).count() > 0 - }) - } - - fn evaluate_has_neighbor_undirected_with<'a>( - 
medrecord: &'a MedRecord, - node_indices: impl Iterator, - operation: NodeOperation, - ) -> impl Iterator { - node_indices.filter(move |index| { - let Ok(neighbors) = medrecord.neighbors_undirected(index) else { - return false; - }; - - operation.clone().evaluate(medrecord, neighbors).count() > 0 - }) - } -} diff --git a/crates/medmodels-core/src/medrecord/querying/operation/operand.rs b/crates/medmodels-core/src/medrecord/querying/operation/operand.rs deleted file mode 100644 index c7b7849e..00000000 --- a/crates/medmodels-core/src/medrecord/querying/operation/operand.rs +++ /dev/null @@ -1,649 +0,0 @@ -#![allow(clippy::should_implement_trait)] - -use super::{ - edge_operation::EdgeIndexOperation, - node_operation::{NodeIndexOperation, NodeOperation}, - AttributeOperation, EdgeOperation, Operation, -}; -use crate::medrecord::{ - EdgeIndex, Group, MedRecord, MedRecordAttribute, MedRecordValue, NodeIndex, -}; -use std::{fmt::Debug, ops::Range}; - -#[derive(Debug, Clone)] -pub enum ArithmeticOperation { - Addition, - Subtraction, - Multiplication, - Division, - Power, - Modulo, -} - -#[derive(Debug, Clone)] -pub enum TransformationOperation { - Round, - Ceil, - Floor, - Abs, - Sqrt, - - Trim, - TrimStart, - TrimEnd, - - Lowercase, - Uppercase, -} - -#[derive(Debug, Clone)] -pub enum ValueOperand { - Value(MedRecordValue), - Evaluate(MedRecordAttribute), - ArithmeticOperation(ArithmeticOperation, MedRecordAttribute, MedRecordValue), - TransformationOperation(TransformationOperation, MedRecordAttribute), - Slice(MedRecordAttribute, Range), -} - -pub trait IntoValueOperand { - fn into_value_operand(self) -> ValueOperand; -} - -impl> IntoValueOperand for T { - fn into_value_operand(self) -> ValueOperand { - ValueOperand::Value(self.into()) - } -} -impl IntoValueOperand for NodeAttributeOperand { - fn into_value_operand(self) -> ValueOperand { - ValueOperand::Evaluate(self.into()) - } -} -impl IntoValueOperand for EdgeAttributeOperand { - fn into_value_operand(self) -> 
ValueOperand { - ValueOperand::Evaluate(self.into()) - } -} -impl IntoValueOperand for ValueOperand { - fn into_value_operand(self) -> ValueOperand { - self - } -} - -#[derive(Debug, Clone)] -pub struct NodeAttributeOperand(MedRecordAttribute); - -impl From for NodeAttributeOperand { - fn from(value: MedRecordAttribute) -> Self { - NodeAttributeOperand(value) - } -} - -impl From for MedRecordAttribute { - fn from(val: NodeAttributeOperand) -> Self { - val.0 - } -} - -impl NodeAttributeOperand { - pub fn greater(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Gt( - self.into(), - operand.into_value_operand(), - )) - } - pub fn less(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Lt( - self.into(), - operand.into_value_operand(), - )) - } - pub fn greater_or_equal(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Gte( - self.into(), - operand.into_value_operand(), - )) - } - pub fn less_or_equal(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Lte( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn equal(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Eq( - self.into(), - operand.into_value_operand(), - )) - } - pub fn not_equal(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Neq( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn r#in(self, operand: Vec>) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::In( - self.into(), - operand.into_iter().map(|value| value.into()).collect(), - )) - } - pub fn not_in(self, operand: Vec>) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::NotIn( - self.into(), - operand.into_iter().map(|value| value.into()).collect(), - )) - } - - pub fn 
starts_with(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::StartsWith( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn ends_with(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::EndsWith( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn contains(self, operand: impl IntoValueOperand) -> NodeOperation { - NodeOperation::Attribute(AttributeOperation::Contains( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn add(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Addition, self.into(), value.into()) - } - - pub fn sub(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation( - ArithmeticOperation::Subtraction, - self.into(), - value.into(), - ) - } - - pub fn mul(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation( - ArithmeticOperation::Multiplication, - self.into(), - value.into(), - ) - } - - pub fn div(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Division, self.into(), value.into()) - } - - pub fn pow(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Power, self.into(), value.into()) - } - - pub fn r#mod(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Modulo, self.into(), value.into()) - } - - pub fn round(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Round, self.into()) - } - - pub fn ceil(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Ceil, self.into()) - } - - pub fn floor(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Floor, self.into()) - } - - pub fn abs(self) -> ValueOperand { - 
ValueOperand::TransformationOperation(TransformationOperation::Abs, self.into()) - } - - pub fn sqrt(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Sqrt, self.into()) - } - - pub fn trim(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Trim, self.into()) - } - - pub fn trim_start(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::TrimStart, self.into()) - } - - pub fn trim_end(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::TrimEnd, self.into()) - } - - pub fn lowercase(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Lowercase, self.into()) - } - - pub fn uppercase(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Uppercase, self.into()) - } - - pub fn slice(self, range: Range) -> ValueOperand { - ValueOperand::Slice(self.into(), range) - } -} - -#[derive(Debug, Clone)] -pub struct EdgeAttributeOperand(MedRecordAttribute); - -impl From for MedRecordAttribute { - fn from(val: EdgeAttributeOperand) -> Self { - val.0 - } -} - -impl EdgeAttributeOperand { - pub fn greater(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Gt( - self.into(), - operand.into_value_operand(), - )) - } - pub fn less(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Lt( - self.into(), - operand.into_value_operand(), - )) - } - pub fn greater_or_equal(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Gte( - self.into(), - operand.into_value_operand(), - )) - } - pub fn less_or_equal(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Lte( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn equal(self, operand: impl IntoValueOperand) 
-> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Eq( - self.into(), - operand.into_value_operand(), - )) - } - pub fn not_equal(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Neq( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn r#in(self, operand: Vec>) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::In( - self.into(), - operand.into_iter().map(|value| value.into()).collect(), - )) - } - pub fn not_in(self, operand: Vec>) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::NotIn( - self.into(), - operand.into_iter().map(|value| value.into()).collect(), - )) - } - - pub fn starts_with(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::StartsWith( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn ends_with(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::EndsWith( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn contains(self, operand: impl IntoValueOperand) -> EdgeOperation { - EdgeOperation::Attribute(AttributeOperation::Contains( - self.into(), - operand.into_value_operand(), - )) - } - - pub fn add(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Addition, self.into(), value.into()) - } - - pub fn sub(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation( - ArithmeticOperation::Subtraction, - self.into(), - value.into(), - ) - } - - pub fn mul(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation( - ArithmeticOperation::Multiplication, - self.into(), - value.into(), - ) - } - - pub fn div(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Division, self.into(), value.into()) - } - - pub fn pow(self, value: impl Into) -> ValueOperand { - 
ValueOperand::ArithmeticOperation(ArithmeticOperation::Power, self.into(), value.into()) - } - - pub fn r#mod(self, value: impl Into) -> ValueOperand { - ValueOperand::ArithmeticOperation(ArithmeticOperation::Modulo, self.into(), value.into()) - } - - pub fn round(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Round, self.into()) - } - - pub fn ceil(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Ceil, self.into()) - } - - pub fn floor(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Floor, self.into()) - } - - pub fn abs(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Abs, self.into()) - } - - pub fn sqrt(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Sqrt, self.into()) - } - - pub fn trim(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Trim, self.into()) - } - - pub fn trim_start(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::TrimStart, self.into()) - } - - pub fn trim_end(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::TrimEnd, self.into()) - } - - pub fn lowercase(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Lowercase, self.into()) - } - - pub fn uppercase(self) -> ValueOperand { - ValueOperand::TransformationOperation(TransformationOperation::Uppercase, self.into()) - } - - pub fn slice(self, range: Range) -> ValueOperand { - ValueOperand::Slice(self.into(), range) - } -} - -#[derive(Debug, Clone)] -pub enum NodeIndexInOperand { - Vector(Vec), - Operation(NodeOperation), -} - -impl From> for NodeIndexInOperand -where - T: Into, -{ - fn from(value: Vec) -> NodeIndexInOperand { - NodeIndexInOperand::Vector(value.into_iter().map(|value| value.into()).collect()) - } -} - -impl From for 
NodeIndexInOperand { - fn from(value: NodeOperation) -> Self { - NodeIndexInOperand::Operation(value) - } -} - -pub(super) trait IntoVecNodeIndex { - fn into_vec_node_index(self, medrecord: &MedRecord) -> Vec; -} - -impl IntoVecNodeIndex for NodeIndexInOperand { - fn into_vec_node_index(self, medrecord: &MedRecord) -> Vec { - match self { - NodeIndexInOperand::Vector(value) => value, - NodeIndexInOperand::Operation(operation) => operation - .evaluate(medrecord, medrecord.node_indices()) - .cloned() - .collect(), - } - } -} - -#[derive(Debug, Clone)] -pub struct NodeIndexOperand; - -impl NodeIndexOperand { - pub fn greater(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::Gt(operand.into())) - } - pub fn less(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::Lt(operand.into())) - } - pub fn greater_or_equal(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::Gte(operand.into())) - } - pub fn less_or_equal(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::Lte(operand.into())) - } - - pub fn equal(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::Eq(operand.into())) - } - pub fn not_equal(self, operand: impl Into) -> NodeOperation { - self.equal(operand).not() - } - - pub fn r#in(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::In(Box::new(operand.into()))) - } - pub fn not_in(self, operand: impl Into) -> NodeOperation { - self.r#in(operand).not() - } - - pub fn starts_with(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::StartsWith(operand.into())) - } - - pub fn ends_with(self, operand: impl Into) -> NodeOperation { - NodeOperation::Index(NodeIndexOperation::EndsWith(operand.into())) - } - - pub fn contains(self, operand: impl Into) -> NodeOperation { - 
NodeOperation::Index(NodeIndexOperation::Contains(operand.into())) - } -} - -#[derive(Debug, Clone)] -pub struct NodeOperand; - -impl NodeOperand { - pub fn in_group(self, operand: impl Into) -> NodeOperation { - NodeOperation::InGroup(operand.into()) - } - - pub fn has_attribute(self, operand: impl Into) -> NodeOperation { - NodeOperation::HasAttribute(operand.into()) - } - - pub fn has_outgoing_edge_with(self, operation: EdgeOperation) -> NodeOperation { - NodeOperation::HasOutgoingEdgeWith(operation.into()) - } - pub fn has_incoming_edge_with(self, operation: EdgeOperation) -> NodeOperation { - NodeOperation::HasIncomingEdgeWith(operation.into()) - } - pub fn has_edge_with(self, operation: EdgeOperation) -> NodeOperation { - NodeOperation::HasOutgoingEdgeWith(operation.clone().into()) - .or(NodeOperation::HasIncomingEdgeWith(operation.into())) - } - - pub fn has_neighbor_with(self, operation: NodeOperation) -> NodeOperation { - NodeOperation::HasNeighborWith(Box::new(operation)) - } - pub fn has_neighbor_undirected_with(self, operation: NodeOperation) -> NodeOperation { - NodeOperation::HasNeighborUndirectedWith(Box::new(operation)) - } - - pub fn attribute(self, attribute: impl Into) -> NodeAttributeOperand { - NodeAttributeOperand(attribute.into()) - } - - pub fn index(self) -> NodeIndexOperand { - NodeIndexOperand - } -} - -pub fn node() -> NodeOperand { - NodeOperand -} - -#[derive(Debug, Clone)] -pub enum EdgeIndexInOperand { - Vector(Vec), - Operation(EdgeOperation), -} - -impl> From> for EdgeIndexInOperand { - fn from(value: Vec) -> EdgeIndexInOperand { - EdgeIndexInOperand::Vector(value.into_iter().map(|value| value.into()).collect()) - } -} - -impl From for EdgeIndexInOperand { - fn from(value: EdgeOperation) -> Self { - EdgeIndexInOperand::Operation(value) - } -} - -pub(super) trait IntoVecEdgeIndex { - fn into_vec_edge_index(self, medrecord: &MedRecord) -> Vec; -} - -impl IntoVecEdgeIndex for EdgeIndexInOperand { - fn into_vec_edge_index(self, 
medrecord: &MedRecord) -> Vec { - match self { - EdgeIndexInOperand::Vector(value) => value, - EdgeIndexInOperand::Operation(operation) => operation - .evaluate(medrecord, medrecord.edge_indices()) - .copied() - .collect(), - } - } -} - -#[derive(Debug, Clone)] -pub struct EdgeIndexOperand; - -impl EdgeIndexOperand { - pub fn greater(self, operand: EdgeIndex) -> EdgeOperation { - EdgeOperation::Index(EdgeIndexOperation::Gt(operand)) - } - pub fn less(self, operand: EdgeIndex) -> EdgeOperation { - EdgeOperation::Index(EdgeIndexOperation::Lt(operand)) - } - pub fn greater_or_equal(self, operand: EdgeIndex) -> EdgeOperation { - EdgeOperation::Index(EdgeIndexOperation::Gte(operand)) - } - pub fn less_or_equal(self, operand: EdgeIndex) -> EdgeOperation { - EdgeOperation::Index(EdgeIndexOperation::Lte(operand)) - } - - pub fn equal(self, operand: EdgeIndex) -> EdgeOperation { - EdgeOperation::Index(EdgeIndexOperation::Eq(operand)) - } - pub fn not_equal(self, operand: EdgeIndex) -> EdgeOperation { - self.equal(operand).not() - } - - pub fn r#in(self, operand: impl Into) -> EdgeOperation { - EdgeOperation::Index(EdgeIndexOperation::In(Box::new(operand.into()))) - } - pub fn not_in(self, operand: impl Into) -> EdgeOperation { - self.r#in(operand).not() - } -} - -#[derive(Debug, Clone)] -pub struct EdgeOperand; - -impl EdgeOperand { - pub fn connected_target(self, operand: impl Into) -> EdgeOperation { - EdgeOperation::ConnectedSource(operand.into()) - } - - pub fn connected_source(self, operand: impl Into) -> EdgeOperation { - EdgeOperation::ConnectedTarget(operand.into()) - } - - pub fn connected(self, operand: impl Into) -> EdgeOperation { - let attribute = operand.into(); - - EdgeOperation::ConnectedSource(attribute.clone()) - .or(EdgeOperation::ConnectedTarget(attribute)) - } - - pub fn in_group(self, operand: impl Into) -> EdgeOperation { - EdgeOperation::InGroup(operand.into()) - } - - pub fn has_attribute(self, operand: impl Into) -> EdgeOperation { - 
EdgeOperation::HasAttribute(operand.into()) - } - - pub fn connected_source_with(self, operation: NodeOperation) -> EdgeOperation { - EdgeOperation::ConnectedSourceWith(operation.into()) - } - - pub fn connected_target_with(self, operation: NodeOperation) -> EdgeOperation { - EdgeOperation::ConnectedTargetWith(operation.into()) - } - - pub fn connected_with(self, operation: NodeOperation) -> EdgeOperation { - EdgeOperation::ConnectedSourceWith(operation.clone().into()) - .or(EdgeOperation::ConnectedTargetWith(operation.into())) - } - - pub fn has_parallel_edges_with(self, operation: EdgeOperation) -> EdgeOperation { - EdgeOperation::HasParallelEdgesWith(Box::new(operation)) - } - - pub fn has_parallel_edges_with_self_comparison( - self, - operation: EdgeOperation, - ) -> EdgeOperation { - EdgeOperation::HasParallelEdgesWithSelfComparison(Box::new(operation)) - } - - pub fn attribute(self, attribute: impl Into) -> EdgeAttributeOperand { - EdgeAttributeOperand(attribute.into()) - } - - pub fn index(self) -> EdgeIndexOperand { - EdgeIndexOperand - } -} - -pub fn edge() -> EdgeOperand { - EdgeOperand -} diff --git a/crates/medmodels-core/src/medrecord/querying/selection.rs b/crates/medmodels-core/src/medrecord/querying/selection.rs deleted file mode 100644 index 82e8356e..00000000 --- a/crates/medmodels-core/src/medrecord/querying/selection.rs +++ /dev/null @@ -1,1741 +0,0 @@ -use super::operation::{EdgeOperation, NodeOperation, Operation}; -use crate::medrecord::{EdgeIndex, MedRecord, NodeIndex}; - -#[derive(Debug)] -pub struct NodeSelection<'a> { - medrecord: &'a MedRecord, - operation: NodeOperation, -} - -impl<'a> NodeSelection<'a> { - pub fn new(medrecord: &'a MedRecord, operation: NodeOperation) -> Self { - Self { - medrecord, - operation, - } - } - - pub fn iter(self) -> impl Iterator { - self.operation - .evaluate(self.medrecord, self.medrecord.node_indices()) - } - - pub fn collect>(self) -> B { - FromIterator::from_iter(self.iter()) - } -} - -#[derive(Debug)] 
-pub struct EdgeSelection<'a> { - medrecord: &'a MedRecord, - operation: EdgeOperation, -} - -impl<'a> EdgeSelection<'a> { - pub fn new(medrecord: &'a MedRecord, operation: EdgeOperation) -> Self { - Self { - medrecord, - operation, - } - } - - pub fn iter(self) -> impl Iterator { - self.operation - .evaluate(self.medrecord, self.medrecord.edge_indices()) - } - - pub fn collect>(self) -> B { - FromIterator::from_iter(self.iter()) - } -} - -#[cfg(test)] -mod test { - use crate::medrecord::{edge, node, Attributes, MedRecord, MedRecordAttribute, NodeIndex}; - use std::collections::HashMap; - - fn create_nodes() -> Vec<(NodeIndex, Attributes)> { - vec![ - ( - "0".into(), - HashMap::from([ - ("lorem".into(), "ipsum".into()), - ("dolor".into(), " ipsum ".into()), - ("test".into(), "Ipsum".into()), - ("integer".into(), 1.into()), - ("float".into(), 0.5.into()), - ]), - ), - ( - "1".into(), - HashMap::from([("amet".into(), "consectetur".into())]), - ), - ( - "2".into(), - HashMap::from([("adipiscing".into(), "elit".into())]), - ), - ("3".into(), HashMap::new()), - ] - } - - fn create_edges() -> Vec<(NodeIndex, NodeIndex, Attributes)> { - vec![ - ( - "0".into(), - "1".into(), - HashMap::from([ - ("sed".into(), "do".into()), - ("eiusmod".into(), "tempor".into()), - ("dolor".into(), " do ".into()), - ("test".into(), "DO".into()), - ]), - ), - ( - "1".into(), - "2".into(), - HashMap::from([("incididunt".into(), "ut".into())]), - ), - ( - "0".into(), - "2".into(), - HashMap::from([ - ("test".into(), 1.into()), - ("integer".into(), 1.into()), - ("float".into(), 0.5.into()), - ]), - ), - ( - "0".into(), - "2".into(), - HashMap::from([("test".into(), 0.into())]), - ), - ] - } - - fn create_medrecord() -> MedRecord { - let nodes = create_nodes(); - let edges = create_edges(); - - MedRecord::from_tuples(nodes, Some(edges), None).unwrap() - } - - #[test] - fn test_iter() { - let medrecord = create_medrecord(); - - assert_eq!( - 1, - medrecord - 
.select_nodes(node().has_attribute("lorem")) - .iter() - .count(), - ); - - assert_eq!( - 1, - medrecord - .select_edges(edge().has_attribute("sed")) - .iter() - .count(), - ); - } - - #[test] - fn test_collect() { - let medrecord = create_medrecord(); - - assert_eq!( - vec![&MedRecordAttribute::from("0")], - medrecord - .select_nodes(node().has_attribute("lorem")) - .collect::>(), - ); - - assert_eq!( - vec![&0], - medrecord - .select_edges(edge().has_attribute("sed")) - .collect::>(), - ); - } - - #[test] - fn test_select_nodes_node() { - let mut medrecord = create_medrecord(); - - medrecord - .add_group("test".into(), Some(vec!["0".into()]), None) - .unwrap(); - - // Node in group - assert_eq!( - 1, - medrecord - .select_nodes(node().in_group("test")) - .iter() - .count(), - ); - - // Node has attribute - assert_eq!( - 1, - medrecord - .select_nodes(node().has_attribute("lorem")) - .iter() - .count(), - ); - - // Node has outgoing edge with - assert_eq!( - 1, - medrecord - .select_nodes(node().has_outgoing_edge_with(edge().index().equal(0))) - .iter() - .count(), - ); - - // Node has incoming edge with - assert_eq!( - 1, - medrecord - .select_nodes(node().has_incoming_edge_with(edge().index().equal(0))) - .iter() - .count(), - ); - - // Node has edge with - assert_eq!( - 2, - medrecord - .select_nodes(node().has_edge_with(edge().index().equal(0))) - .iter() - .count(), - ); - - // Node has neighbor with - assert_eq!( - 2, - medrecord - .select_nodes(node().has_neighbor_with(node().index().equal("2"))) - .iter() - .count(), - ); - assert_eq!( - 1, - medrecord - .select_nodes(node().has_neighbor_with(node().index().equal("1"))) - .iter() - .count(), - ); - - // Node has undirected neighbor with - assert_eq!( - 2, - medrecord - .select_nodes(node().has_neighbor_undirected_with(node().index().equal("1"))) - .iter() - .count(), - ); - } - - #[test] - fn test_select_nodes_node_index() { - let medrecord = create_medrecord(); - - // Index greater - assert_eq!( - 2, - 
medrecord - .select_nodes(node().index().greater("1")) - .iter() - .count(), - ); - - // Index less - assert_eq!( - 1, - medrecord - .select_nodes(node().index().less("1")) - .iter() - .count(), - ); - - // Index greater or equal - assert_eq!( - 3, - medrecord - .select_nodes(node().index().greater_or_equal("1")) - .iter() - .count(), - ); - - // Index less or equal - assert_eq!( - 2, - medrecord - .select_nodes(node().index().less_or_equal("1")) - .iter() - .count(), - ); - - // Index equal - assert_eq!( - 1, - medrecord - .select_nodes(node().index().equal("1")) - .iter() - .count(), - ); - - // Index not equal - assert_eq!( - 3, - medrecord - .select_nodes(node().index().not_equal("1")) - .iter() - .count(), - ); - - // Index in - assert_eq!( - 1, - medrecord - .select_nodes(node().index().r#in(vec!["1"])) - .iter() - .count(), - ); - - // Index in - assert_eq!( - 1, - medrecord - .select_nodes(node().index().r#in(node().has_attribute("lorem"))) - .iter() - .count(), - ); - - // Index not in - assert_eq!( - 3, - medrecord - .select_nodes(node().index().not_in(node().has_attribute("lorem"))) - .iter() - .count(), - ); - - // Index starts with - assert_eq!( - 1, - medrecord - .select_nodes(node().index().starts_with("1")) - .iter() - .count(), - ); - - // Index ends with - assert_eq!( - 1, - medrecord - .select_nodes(node().index().ends_with("1")) - .iter() - .count(), - ); - - // Index contains - assert_eq!( - 1, - medrecord - .select_nodes(node().index().contains("1")) - .iter() - .count(), - ); - } - - #[test] - fn test_select_nodes_node_attribute() { - let medrecord = create_medrecord(); - - // Attribute greater - assert_eq!( - 0, - medrecord - .select_nodes(node().attribute("lorem").greater("ipsum")) - .iter() - .count(), - ); - - // Attribute less - assert_eq!( - 0, - medrecord - .select_nodes(node().attribute("lorem").less("ipsum")) - .iter() - .count(), - ); - - // Attribute greater or equal - assert_eq!( - 1, - medrecord - 
.select_nodes(node().attribute("lorem").greater_or_equal("ipsum")) - .iter() - .count(), - ); - - // Attribute less or equal - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").less_or_equal("ipsum")) - .iter() - .count(), - ); - - // Attribute equal - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").equal("ipsum")) - .iter() - .count(), - ); - - // Attribute not equal - assert_eq!( - 0, - medrecord - .select_nodes(node().attribute("lorem").not_equal("ipsum")) - .iter() - .count(), - ); - - // Attribute in - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").r#in(vec!["ipsum"])) - .iter() - .count(), - ); - - // Attribute not in - assert_eq!( - 0, - medrecord - .select_nodes(node().attribute("lorem").not_in(vec!["ipsum"])) - .iter() - .count(), - ); - - // Attribute starts with - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").starts_with("ip")) - .iter() - .count(), - ); - - // Attribute ends with - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").ends_with("um")) - .iter() - .count(), - ); - - // Attribute contains - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").contains("su")) - .iter() - .count(), - ); - - // Attribute compare to attribute - assert_eq!( - 1, - medrecord - .select_nodes(node().attribute("lorem").equal(node().attribute("lorem"))) - .iter() - .count(), - ); - - // Attribute compare to attribute - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute add - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").add("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute add - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").add("10")) - ) - .iter() 
- .count(), - ); - - // Attribute compare to attribute sub - // Returns nothing because can't sub a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").sub("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - // Doesn't work because can't sub a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").sub("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("integer").sub(10)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").sub(10)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mul - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").mul(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mul - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").mul(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - // Returns nothing because can't div a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").div("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - // Doesn't work because can't div a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").div("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("integer").div(2)) - ) - .iter() - .count(), - ); 
- - // Attribute compare to attribute div - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").div(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - // Returns nothing because can't pow a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").pow("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - // Doesn't work because can't pow a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").pow("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("integer").pow(2)) // 1 ** 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("integer") - .not_equal(node().attribute("integer").pow(2)) // 1 ** 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - // Returns nothing because can't mod a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").r#mod("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - // Doesn't work because can't mod a string - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").r#mod("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("integer").r#mod(2)) // 1 % 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("integer") - 
.not_equal(node().attribute("integer").r#mod(2)) // 1 % 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("lorem").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .not_equal(node().attribute("lorem").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("float").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("float") - .not_equal(node().attribute("float").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute ceil - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("float").ceil()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute ceil - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("float") - .not_equal(node().attribute("float").ceil()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute floor - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("float").floor()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute floor - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("float") - .not_equal(node().attribute("float").floor()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute abs - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("integer") - .equal(node().attribute("integer").abs()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sqrt - assert_eq!( - 1, - medrecord - .select_nodes( - node() - 
.attribute("integer") - .equal(node().attribute("integer").sqrt()) // sqrt(1) = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute trim - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("dolor").trim()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute trim_start - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("dolor").trim_start()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute trim_end - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("dolor").trim_end()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute lowercase - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("test").lowercase()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute uppercase - assert_eq!( - 0, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("test").uppercase()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute slice - assert_eq!( - 1, - medrecord - .select_nodes( - node() - .attribute("lorem") - .equal(node().attribute("dolor").slice(2..7)) - ) - .iter() - .count(), - ); - } - - #[test] - fn test_select_edges_edge() { - let mut medrecord = create_medrecord(); - - medrecord - .add_group("test".into(), None, Some(vec![0])) - .unwrap(); - - // Edge connected to target - assert_eq!( - 3, - medrecord - .select_edges(edge().connected_target("2")) - .iter() - .count(), - ); - - // Edge connected to source - assert_eq!( - 3, - medrecord - .select_edges(edge().connected_source("0")) - .iter() - .count(), - ); - - // Edge connected - assert_eq!( - 2, - medrecord.select_edges(edge().connected("1")).iter().count(), - ); - - // Edge in group - assert_eq!( - 1, - medrecord - .select_edges(edge().in_group("test")) - .iter() - .count(), - 
); - - // Edge has attribute - assert_eq!( - 1, - medrecord - .select_edges(edge().has_attribute("sed")) - .iter() - .count(), - ); - - // Edge connected to target with - assert_eq!( - 1, - medrecord - .select_edges(edge().connected_target_with(node().index().equal("1"))) - .iter() - .count(), - ); - - // Edge connected to source with - assert_eq!( - 3, - medrecord - .select_edges(edge().connected_source_with(node().index().equal("0"))) - .iter() - .count(), - ); - - // Edge connected with - assert_eq!( - 2, - medrecord - .select_edges(edge().connected_with(node().index().equal("1"))) - .iter() - .count(), - ); - - // Edge has parallel edges with - assert_eq!( - 2, - medrecord - .select_edges(edge().has_parallel_edges_with(edge().has_attribute("test"))) - .iter() - .count(), - ); - - // Edge has parallel edges with self comparison - assert_eq!( - 1, - medrecord - .select_edges( - edge().has_parallel_edges_with_self_comparison( - edge() - .attribute("test") - .equal(edge().attribute("test").sub(1)) - ) - ) - .iter() - .count(), - ); - } - - #[test] - fn test_select_edges_edge_index() { - let medrecord = create_medrecord(); - - // Index greater - assert_eq!( - 2, - medrecord - .select_edges(edge().index().greater(1)) - .iter() - .count(), - ); - - // Index less - assert_eq!( - 1, - medrecord - .select_edges(edge().index().less(1)) - .iter() - .count(), - ); - - // Index greater or equal - assert_eq!( - 3, - medrecord - .select_edges(edge().index().greater_or_equal(1)) - .iter() - .count(), - ); - - // Index less or equal - assert_eq!( - 2, - medrecord - .select_edges(edge().index().less_or_equal(1)) - .iter() - .count(), - ); - - // Index equal - assert_eq!( - 1, - medrecord - .select_edges(edge().index().equal(1)) - .iter() - .count(), - ); - - // Index not equal - assert_eq!( - 3, - medrecord - .select_edges(edge().index().not_equal(1)) - .iter() - .count(), - ); - - // Index in - assert_eq!( - 1, - medrecord - .select_edges(edge().index().r#in(vec![1_usize])) - 
.iter() - .count(), - ); - - // Index in - assert_eq!( - 1, - medrecord - .select_edges(edge().index().r#in(edge().has_attribute("sed"))) - .iter() - .count(), - ); - - // Index not in - assert_eq!( - 3, - medrecord - .select_edges(edge().index().not_in(edge().has_attribute("sed"))) - .iter() - .count(), - ); - } - - #[test] - fn test_select_edges_edge_attribute() { - let medrecord = create_medrecord(); - - // Attribute greater - assert_eq!( - 0, - medrecord - .select_edges(edge().attribute("sed").greater("do")) - .iter() - .count(), - ); - - // Attribute less - assert_eq!( - 0, - medrecord - .select_edges(edge().attribute("sed").less("do")) - .iter() - .count(), - ); - - // Attribute greater or equal - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").greater_or_equal("do")) - .iter() - .count(), - ); - - // Attribute less or equal - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").less_or_equal("do")) - .iter() - .count(), - ); - - // Attribute equal - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").equal("do")) - .iter() - .count(), - ); - - // Attribute not equal - assert_eq!( - 0, - medrecord - .select_edges(edge().attribute("sed").not_equal("do")) - .iter() - .count(), - ); - - // Attribute in - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").r#in(vec!["do"])) - .iter() - .count(), - ); - - // Attribute not in - assert_eq!( - 0, - medrecord - .select_edges(edge().attribute("sed").not_in(vec!["do"])) - .iter() - .count(), - ); - - // Attribute starts with - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").starts_with("d")) - .iter() - .count(), - ); - - // Attribute ends with - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").ends_with("o")) - .iter() - .count(), - ); - - // Attribute contains - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").contains("do")) - .iter() - .count(), - ); - - // Attribute compare to 
attribute - assert_eq!( - 1, - medrecord - .select_edges(edge().attribute("sed").equal(edge().attribute("sed"))) - .iter() - .count(), - ); - - // Attribute compare to attribute - assert_eq!( - 0, - medrecord - .select_edges(edge().attribute("sed").not_equal(edge().attribute("sed"))) - .iter() - .count(), - ); - - // Attribute compare to attribute add - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("sed").add("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute add - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("sed") - .not_equal(edge().attribute("sed").add("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - // Returns nothing because can't sub a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("sed").sub("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - // Doesn't work because can't sub a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .not_equal(edge().attribute("sed").sub("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("integer").sub(10)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sub - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").sub(10)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mul - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("sed").mul(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mul - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("sed") - .not_equal(edge().attribute("sed").mul(2)) - ) - .iter() - .count(), - ); - - // Attribute 
compare to attribute div - // Returns nothing because can't div a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("sed").div("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - // Doesn't work because can't div a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .not_equal(edge().attribute("sed").div("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("integer").div(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute div - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").div(2)) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - // Returns nothing because can't pow a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("lorem") - .equal(edge().attribute("lorem").pow("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - // Doesn't work because can't pow a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("lorem") - .not_equal(edge().attribute("lorem").pow("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("integer").pow(2)) // 1 ** 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute pow - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").pow(2)) // 1 ** 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - // Returns nothing because can't mod a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("lorem") - 
.equal(edge().attribute("lorem").r#mod("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - // Doesn't work because can't mod a string - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("lorem") - .not_equal(edge().attribute("lorem").r#mod("10")) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("integer").r#mod(2)) // 1 % 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute mod - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("integer") - .not_equal(edge().attribute("integer").r#mod(2)) // 1 % 2 = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("sed").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .not_equal(edge().attribute("sed").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("float").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute round - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("float") - .not_equal(edge().attribute("float").round()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute ceil - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("float").ceil()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute ceil - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("float") - .not_equal(edge().attribute("float").ceil()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute floor - 
assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("float").floor()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute floor - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("float") - .not_equal(edge().attribute("float").floor()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute abs - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("integer").abs()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute sqrt - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("integer") - .equal(edge().attribute("integer").sqrt()) // sqrt(1) = 1 - ) - .iter() - .count(), - ); - - // Attribute compare to attribute trim - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("dolor").trim()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute trim_start - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("dolor").trim_start()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute trim_end - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("dolor").trim_end()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute lowercase - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("test").lowercase()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute uppercase - assert_eq!( - 0, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("test").uppercase()) - ) - .iter() - .count(), - ); - - // Attribute compare to attribute slice - assert_eq!( - 1, - medrecord - .select_edges( - edge() - .attribute("sed") - .equal(edge().attribute("dolor").slice(2..4)) - ) - .iter() - .count(), - ); - } -} diff --git 
a/crates/medmodels-core/src/medrecord/querying/traits.rs b/crates/medmodels-core/src/medrecord/querying/traits.rs new file mode 100644 index 00000000..4e8d33e8 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/traits.rs @@ -0,0 +1,21 @@ +use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; + +pub trait DeepClone { + fn deep_clone(&self) -> Self; +} + +pub(crate) trait ReadWriteOrPanic { + fn read_or_panic(&self) -> RwLockReadGuard<'_, T>; + + fn write_or_panic(&self) -> RwLockWriteGuard<'_, T>; +} + +impl ReadWriteOrPanic for RwLock { + fn read_or_panic(&self) -> RwLockReadGuard<'_, T> { + self.read().unwrap() + } + + fn write_or_panic(&self) -> RwLockWriteGuard<'_, T> { + self.write().unwrap() + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/values/mod.rs b/crates/medmodels-core/src/medrecord/querying/values/mod.rs new file mode 100644 index 00000000..bf2e2f4a --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/values/mod.rs @@ -0,0 +1,185 @@ +mod operand; +mod operation; + +use super::{ + attributes::{ + self, AttributesTreeOperation, MultipleAttributesOperand, MultipleAttributesOperation, + }, + edges::{EdgeOperand, EdgeOperation}, + nodes::{NodeOperand, NodeOperation}, + BoxedIterator, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{MedRecordAttribute, MedRecordValue}, + MedRecord, +}; +pub use operand::MultipleValuesOperand; +use std::fmt::Display; + +macro_rules! get_attributes { + ($operand:ident, $medrecord:ident, $operation:ident, $multiple_attributes_operand:ident) => {{ + let indices = $operand.evaluate($medrecord)?; + + let attributes = $operation::get_attributes($medrecord, indices); + + let attributes = $multiple_attributes_operand + .context + .evaluate($medrecord, attributes)?; + + let attributes: Box> = + match $multiple_attributes_operand.kind { + attributes::MultipleKind::Max => { + Box::new(AttributesTreeOperation::get_max(attributes)?) 
+ } + attributes::MultipleKind::Min => { + Box::new(AttributesTreeOperation::get_min(attributes)?) + } + attributes::MultipleKind::Count => { + Box::new(AttributesTreeOperation::get_count(attributes)?) + } + attributes::MultipleKind::Sum => { + Box::new(AttributesTreeOperation::get_sum(attributes)?) + } + attributes::MultipleKind::First => { + Box::new(AttributesTreeOperation::get_first(attributes)?) + } + attributes::MultipleKind::Last => { + Box::new(AttributesTreeOperation::get_last(attributes)?) + } + }; + + let attributes = $multiple_attributes_operand.evaluate($medrecord, attributes)?; + + Box::new( + MultipleAttributesOperation::get_values($medrecord, attributes)? + .map(|(_, value)| value), + ) + }}; +} + +#[derive(Debug, Clone)] +pub enum SingleKind { + Max, + Min, + Mean, + Median, + Mode, + Std, + Var, + Count, + Sum, + First, + Last, +} + +#[derive(Debug, Clone)] +pub enum SingleComparisonKind { + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, + EqualTo, + NotEqualTo, + StartsWith, + EndsWith, + Contains, +} + +#[derive(Debug, Clone)] +pub enum MultipleComparisonKind { + IsIn, + IsNotIn, +} + +#[derive(Debug, Clone)] +pub enum BinaryArithmeticKind { + Add, + Sub, + Mul, + Div, + Pow, + Mod, +} + +impl Display for BinaryArithmeticKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BinaryArithmeticKind::Add => write!(f, "add"), + BinaryArithmeticKind::Sub => write!(f, "sub"), + BinaryArithmeticKind::Mul => write!(f, "mul"), + BinaryArithmeticKind::Div => write!(f, "div"), + BinaryArithmeticKind::Pow => write!(f, "pow"), + BinaryArithmeticKind::Mod => write!(f, "mod"), + } + } +} + +#[derive(Debug, Clone)] +pub enum UnaryArithmeticKind { + Round, + Ceil, + Floor, + Abs, + Sqrt, + Trim, + TrimStart, + TrimEnd, + Lowercase, + Uppercase, +} + +#[allow(clippy::enum_variant_names)] +#[derive(Debug, Clone)] +pub enum Context { + NodeOperand(NodeOperand), + EdgeOperand(EdgeOperand), + 
MultipleAttributesOperand(MultipleAttributesOperand), +} + +impl Context { + pub(crate) fn get_values<'a>( + &self, + medrecord: &'a MedRecord, + attribute: MedRecordAttribute, + ) -> MedRecordResult> { + Ok(match self { + Self::NodeOperand(node_operand) => { + let node_indices = node_operand.evaluate(medrecord)?; + + Box::new( + NodeOperation::get_values(medrecord, node_indices, attribute) + .map(|(_, value)| value), + ) + } + Self::EdgeOperand(edge_operand) => { + let edge_indices = edge_operand.evaluate(medrecord)?; + + Box::new( + EdgeOperation::get_values(medrecord, edge_indices, attribute) + .map(|(_, value)| value), + ) + } + Self::MultipleAttributesOperand(multiple_attributes_operand) => { + match &multiple_attributes_operand.context.context { + attributes::Context::NodeOperand(node_operand) => { + get_attributes!( + node_operand, + medrecord, + NodeOperation, + multiple_attributes_operand + ) + } + attributes::Context::EdgeOperand(edge_operand) => { + get_attributes!( + edge_operand, + medrecord, + EdgeOperation, + multiple_attributes_operand + ) + } + } + } + }) + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/values/operand.rs b/crates/medmodels-core/src/medrecord/querying/values/operand.rs new file mode 100644 index 00000000..01796ed7 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/values/operand.rs @@ -0,0 +1,590 @@ +use super::{ + operation::{MultipleValuesOperation, SingleValueOperation}, + BinaryArithmeticKind, Context, MultipleComparisonKind, SingleComparisonKind, SingleKind, + UnaryArithmeticKind, +}; +use crate::{ + errors::MedRecordResult, + medrecord::{ + querying::{ + traits::{DeepClone, ReadWriteOrPanic}, + BoxedIterator, + }, + MedRecordAttribute, MedRecordValue, Wrapper, + }, + MedRecord, +}; +use std::hash::Hash; + +macro_rules! 
implement_value_operation { + ($name:ident, $variant:ident) => { + pub fn $name(&mut self) -> Wrapper { + let operand = + Wrapper::::new(self.deep_clone(), SingleKind::$variant); + + self.operations + .push(MultipleValuesOperation::ValueOperation { + operand: operand.clone(), + }); + + operand + } + }; +} + +macro_rules! implement_single_value_comparison_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations + .push($operation::SingleValueComparisonOperation { + operand: value.into(), + kind: SingleComparisonKind::$kind, + }); + } + }; +} + +macro_rules! implement_binary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name>(&mut self, value: V) { + self.operations.push($operation::BinaryArithmeticOpration { + operand: value.into(), + kind: BinaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_unary_arithmetic_operation { + ($name:ident, $operation:ident, $kind:ident) => { + pub fn $name(&mut self) { + self.operations.push($operation::UnaryArithmeticOperation { + kind: UnaryArithmeticKind::$kind, + }); + } + }; +} + +macro_rules! implement_assertion_operation { + ($name:ident, $operation:expr) => { + pub fn $name(&mut self) { + self.operations.push($operation); + } + }; +} + +macro_rules! implement_wrapper_operand { + ($name:ident) => { + pub fn $name(&self) { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! implement_wrapper_operand_with_return { + ($name:ident, $return_operand:ident) => { + pub fn $name(&self) -> Wrapper<$return_operand> { + self.0.write_or_panic().$name() + } + }; +} + +macro_rules! 
implement_wrapper_operand_with_argument { + ($name:ident, $value_type:ty) => { + pub fn $name(&self, value: $value_type) { + self.0.write_or_panic().$name(value) + } + }; +} + +#[derive(Debug, Clone)] +pub enum SingleValueComparisonOperand { + Operand(SingleValueOperand), + Value(MedRecordValue), +} + +impl DeepClone for SingleValueComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Value(value) => Self::Value(value.clone()), + } + } +} + +impl From> for SingleValueComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for SingleValueComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From for SingleValueComparisonOperand { + fn from(value: V) -> Self { + Self::Value(value.into()) + } +} + +#[derive(Debug, Clone)] +pub enum MultipleValuesComparisonOperand { + Operand(MultipleValuesOperand), + Values(Vec), +} + +impl DeepClone for MultipleValuesComparisonOperand { + fn deep_clone(&self) -> Self { + match self { + Self::Operand(operand) => Self::Operand(operand.deep_clone()), + Self::Values(value) => Self::Values(value.clone()), + } + } +} + +impl From> for MultipleValuesComparisonOperand { + fn from(value: Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl From<&Wrapper> for MultipleValuesComparisonOperand { + fn from(value: &Wrapper) -> Self { + Self::Operand(value.0.read_or_panic().deep_clone()) + } +} + +impl> From> for MultipleValuesComparisonOperand { + fn from(value: Vec) -> Self { + Self::Values(value.into_iter().map(Into::into).collect()) + } +} + +impl + Clone, const N: usize> From<[V; N]> + for MultipleValuesComparisonOperand +{ + fn from(value: [V; N]) -> Self { + value.to_vec().into() + } +} + +#[derive(Debug, Clone)] +pub struct MultipleValuesOperand { + pub(crate) 
context: Context, + pub(crate) attribute: MedRecordAttribute, + operations: Vec, +} + +impl DeepClone for MultipleValuesOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.clone(), + attribute: self.attribute.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl MultipleValuesOperand { + pub(crate) fn new(context: Context, attribute: MedRecordAttribute) -> Self { + Self { + context, + attribute, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult> { + let values = Box::new(values) as BoxedIterator<(&'a T, MedRecordValue)>; + + self.operations + .iter() + .try_fold(values, |value_tuples, operation| { + operation.evaluate(medrecord, value_tuples) + }) + } + + implement_value_operation!(max, Max); + implement_value_operation!(min, Min); + implement_value_operation!(mean, Mean); + implement_value_operation!(median, Median); + implement_value_operation!(mode, Mode); + implement_value_operation!(std, Std); + implement_value_operation!(var, Var); + implement_value_operation!(count, Count); + implement_value_operation!(sum, Sum); + implement_value_operation!(first, First); + implement_value_operation!(last, Last); + + implement_single_value_comparison_operation!( + greater_than, + MultipleValuesOperation, + GreaterThan + ); + implement_single_value_comparison_operation!( + greater_than_or_equal_to, + MultipleValuesOperation, + GreaterThanOrEqualTo + ); + implement_single_value_comparison_operation!(less_than, MultipleValuesOperation, LessThan); + implement_single_value_comparison_operation!( + less_than_or_equal_to, + MultipleValuesOperation, + LessThanOrEqualTo + ); + implement_single_value_comparison_operation!(equal_to, MultipleValuesOperation, EqualTo); + implement_single_value_comparison_operation!(not_equal_to, MultipleValuesOperation, NotEqualTo); + 
implement_single_value_comparison_operation!(starts_with, MultipleValuesOperation, StartsWith); + implement_single_value_comparison_operation!(ends_with, MultipleValuesOperation, EndsWith); + implement_single_value_comparison_operation!(contains, MultipleValuesOperation, Contains); + + pub fn is_in>(&mut self, values: V) { + self.operations + .push(MultipleValuesOperation::MultipleValuesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + pub fn is_not_in>(&mut self, values: V) { + self.operations + .push(MultipleValuesOperation::MultipleValuesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, MultipleValuesOperation, Add); + implement_binary_arithmetic_operation!(sub, MultipleValuesOperation, Sub); + implement_binary_arithmetic_operation!(mul, MultipleValuesOperation, Mul); + implement_binary_arithmetic_operation!(div, MultipleValuesOperation, Div); + implement_binary_arithmetic_operation!(pow, MultipleValuesOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, MultipleValuesOperation, Mod); + + implement_unary_arithmetic_operation!(round, MultipleValuesOperation, Round); + implement_unary_arithmetic_operation!(ceil, MultipleValuesOperation, Ceil); + implement_unary_arithmetic_operation!(floor, MultipleValuesOperation, Floor); + implement_unary_arithmetic_operation!(abs, MultipleValuesOperation, Abs); + implement_unary_arithmetic_operation!(sqrt, MultipleValuesOperation, Sqrt); + implement_unary_arithmetic_operation!(trim, MultipleValuesOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, MultipleValuesOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, MultipleValuesOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, MultipleValuesOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, MultipleValuesOperation, Uppercase); + + pub fn 
slice(&mut self, start: usize, end: usize) { + self.operations + .push(MultipleValuesOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, MultipleValuesOperation::IsString); + implement_assertion_operation!(is_int, MultipleValuesOperation::IsInt); + implement_assertion_operation!(is_float, MultipleValuesOperation::IsFloat); + implement_assertion_operation!(is_bool, MultipleValuesOperation::IsBool); + implement_assertion_operation!(is_datetime, MultipleValuesOperation::IsDateTime); + implement_assertion_operation!(is_null, MultipleValuesOperation::IsNull); + implement_assertion_operation!(is_max, MultipleValuesOperation::IsMax); + implement_assertion_operation!(is_min, MultipleValuesOperation::IsMin); + + pub fn either_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.attribute.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.attribute.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(MultipleValuesOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: Context, attribute: MedRecordAttribute) -> Self { + MultipleValuesOperand::new(context, attribute).into() + } + + pub(crate) fn evaluate<'a, T: 'a + Eq + Hash>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, values) + } + + implement_wrapper_operand_with_return!(max, SingleValueOperand); + implement_wrapper_operand_with_return!(min, SingleValueOperand); + implement_wrapper_operand_with_return!(mean, SingleValueOperand); + implement_wrapper_operand_with_return!(median, SingleValueOperand); + implement_wrapper_operand_with_return!(mode, SingleValueOperand); + implement_wrapper_operand_with_return!(std, 
SingleValueOperand); + implement_wrapper_operand_with_return!(var, SingleValueOperand); + implement_wrapper_operand_with_return!(count, SingleValueOperand); + implement_wrapper_operand_with_return!(sum, SingleValueOperand); + implement_wrapper_operand_with_return!(first, SingleValueOperand); + implement_wrapper_operand_with_return!(last, SingleValueOperand); + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(div, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(round); + implement_wrapper_operand!(ceil); + implement_wrapper_operand!(floor); + implement_wrapper_operand!(abs); + implement_wrapper_operand!(sqrt); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + 
implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_float); + implement_wrapper_operand!(is_bool); + implement_wrapper_operand!(is_datetime); + implement_wrapper_operand!(is_null); + implement_wrapper_operand!(is_max); + implement_wrapper_operand!(is_min); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().either_or(either_query, or_query); + } +} + +#[derive(Debug, Clone)] +pub struct SingleValueOperand { + pub(crate) context: MultipleValuesOperand, + pub(crate) kind: SingleKind, + operations: Vec, +} + +impl DeepClone for SingleValueOperand { + fn deep_clone(&self) -> Self { + Self { + context: self.context.deep_clone(), + kind: self.kind.clone(), + operations: self.operations.iter().map(DeepClone::deep_clone).collect(), + } + } +} + +impl SingleValueOperand { + pub(crate) fn new(context: MultipleValuesOperand, kind: SingleKind) -> Self { + Self { + context, + kind, + operations: Vec::new(), + } + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: MedRecordValue, + ) -> MedRecordResult> { + self.operations + .iter() + .try_fold(Some(value), |value, operation| { + if let Some(value) = value { + operation.evaluate(medrecord, value) + } else { + Ok(None) + } + }) + } + + implement_single_value_comparison_operation!(greater_than, SingleValueOperation, GreaterThan); + implement_single_value_comparison_operation!( + greater_than_or_equal_to, + SingleValueOperation, + GreaterThanOrEqualTo + ); + implement_single_value_comparison_operation!(less_than, SingleValueOperation, LessThan); + implement_single_value_comparison_operation!( + less_than_or_equal_to, + SingleValueOperation, + LessThanOrEqualTo + ); + implement_single_value_comparison_operation!(equal_to, SingleValueOperation, EqualTo); + implement_single_value_comparison_operation!(not_equal_to, SingleValueOperation, NotEqualTo); + 
implement_single_value_comparison_operation!(starts_with, SingleValueOperation, StartsWith); + implement_single_value_comparison_operation!(ends_with, SingleValueOperation, EndsWith); + implement_single_value_comparison_operation!(contains, SingleValueOperation, Contains); + + pub fn is_in>(&mut self, values: V) { + self.operations + .push(SingleValueOperation::MultipleValuesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsIn, + }); + } + + pub fn is_not_in>(&mut self, values: V) { + self.operations + .push(SingleValueOperation::MultipleValuesComparisonOperation { + operand: values.into(), + kind: MultipleComparisonKind::IsNotIn, + }); + } + + implement_binary_arithmetic_operation!(add, SingleValueOperation, Add); + implement_binary_arithmetic_operation!(sub, SingleValueOperation, Sub); + implement_binary_arithmetic_operation!(mul, SingleValueOperation, Mul); + implement_binary_arithmetic_operation!(div, SingleValueOperation, Div); + implement_binary_arithmetic_operation!(pow, SingleValueOperation, Pow); + implement_binary_arithmetic_operation!(r#mod, SingleValueOperation, Mod); + + implement_unary_arithmetic_operation!(round, SingleValueOperation, Round); + implement_unary_arithmetic_operation!(ceil, SingleValueOperation, Ceil); + implement_unary_arithmetic_operation!(floor, SingleValueOperation, Floor); + implement_unary_arithmetic_operation!(abs, SingleValueOperation, Abs); + implement_unary_arithmetic_operation!(sqrt, SingleValueOperation, Sqrt); + implement_unary_arithmetic_operation!(trim, SingleValueOperation, Trim); + implement_unary_arithmetic_operation!(trim_start, SingleValueOperation, TrimStart); + implement_unary_arithmetic_operation!(trim_end, SingleValueOperation, TrimEnd); + implement_unary_arithmetic_operation!(lowercase, SingleValueOperation, Lowercase); + implement_unary_arithmetic_operation!(uppercase, SingleValueOperation, Uppercase); + + pub fn slice(&mut self, start: usize, end: usize) { + self.operations + 
.push(SingleValueOperation::Slice(start..end)); + } + + implement_assertion_operation!(is_string, SingleValueOperation::IsString); + implement_assertion_operation!(is_int, SingleValueOperation::IsInt); + implement_assertion_operation!(is_float, SingleValueOperation::IsFloat); + implement_assertion_operation!(is_bool, SingleValueOperation::IsBool); + implement_assertion_operation!(is_datetime, SingleValueOperation::IsDateTime); + implement_assertion_operation!(is_null, SingleValueOperation::IsNull); + + pub fn eiter_or(&mut self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + let mut either_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + let mut or_operand = + Wrapper::::new(self.context.clone(), self.kind.clone()); + + either_query(&mut either_operand); + or_query(&mut or_operand); + + self.operations.push(SingleValueOperation::EitherOr { + either: either_operand, + or: or_operand, + }); + } +} + +impl Wrapper { + pub(crate) fn new(context: MultipleValuesOperand, kind: SingleKind) -> Self { + SingleValueOperand::new(context, kind).into() + } + + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: MedRecordValue, + ) -> MedRecordResult> { + self.0.read_or_panic().evaluate(medrecord, value) + } + + implement_wrapper_operand_with_argument!(greater_than, impl Into); + implement_wrapper_operand_with_argument!( + greater_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(less_than, impl Into); + implement_wrapper_operand_with_argument!( + less_than_or_equal_to, + impl Into + ); + implement_wrapper_operand_with_argument!(equal_to, impl Into); + implement_wrapper_operand_with_argument!(not_equal_to, impl Into); + implement_wrapper_operand_with_argument!(starts_with, impl Into); + implement_wrapper_operand_with_argument!(ends_with, impl Into); + implement_wrapper_operand_with_argument!(contains, impl Into); + implement_wrapper_operand_with_argument!(is_in, 
impl Into); + implement_wrapper_operand_with_argument!(is_not_in, impl Into); + implement_wrapper_operand_with_argument!(add, impl Into); + implement_wrapper_operand_with_argument!(sub, impl Into); + implement_wrapper_operand_with_argument!(mul, impl Into); + implement_wrapper_operand_with_argument!(div, impl Into); + implement_wrapper_operand_with_argument!(pow, impl Into); + implement_wrapper_operand_with_argument!(r#mod, impl Into); + + implement_wrapper_operand!(round); + implement_wrapper_operand!(ceil); + implement_wrapper_operand!(floor); + implement_wrapper_operand!(abs); + implement_wrapper_operand!(sqrt); + implement_wrapper_operand!(trim); + implement_wrapper_operand!(trim_start); + implement_wrapper_operand!(trim_end); + implement_wrapper_operand!(lowercase); + implement_wrapper_operand!(uppercase); + + pub fn slice(&self, start: usize, end: usize) { + self.0.write_or_panic().slice(start, end) + } + + implement_wrapper_operand!(is_string); + implement_wrapper_operand!(is_int); + implement_wrapper_operand!(is_float); + implement_wrapper_operand!(is_bool); + implement_wrapper_operand!(is_datetime); + implement_wrapper_operand!(is_null); + + pub fn either_or(&self, either_query: EQ, or_query: OQ) + where + EQ: FnOnce(&mut Wrapper), + OQ: FnOnce(&mut Wrapper), + { + self.0.write_or_panic().eiter_or(either_query, or_query); + } +} diff --git a/crates/medmodels-core/src/medrecord/querying/values/operation.rs b/crates/medmodels-core/src/medrecord/querying/values/operation.rs new file mode 100644 index 00000000..2a559d99 --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/values/operation.rs @@ -0,0 +1,934 @@ +use super::{ + operand::{ + MultipleValuesComparisonOperand, MultipleValuesOperand, SingleValueComparisonOperand, + SingleValueOperand, + }, + BinaryArithmeticKind, MultipleComparisonKind, SingleComparisonKind, SingleKind, + UnaryArithmeticKind, +}; +use crate::{ + errors::{MedRecordError, MedRecordResult}, + medrecord::{ + datatypes::{ + 
Abs, Ceil, Contains, EndsWith, Floor, Lowercase, Mod, Pow, Round, Slice, Sqrt, + StartsWith, Trim, TrimEnd, TrimStart, Uppercase, + }, + querying::{ + traits::{DeepClone, ReadWriteOrPanic}, + BoxedIterator, + }, + DataType, MedRecordValue, Wrapper, + }, + MedRecord, +}; +use itertools::Itertools; +use std::{ + cmp::Ordering, + hash::Hash, + ops::{Add, Div, Mul, Range, Sub}, +}; + +macro_rules! get_single_operand_value { + ($kind:ident, $values:expr) => { + match $kind { + SingleKind::Max => MultipleValuesOperation::get_max($values)?.1, + SingleKind::Min => MultipleValuesOperation::get_min($values)?.1, + SingleKind::Mean => MultipleValuesOperation::get_mean($values)?, + SingleKind::Median => MultipleValuesOperation::get_median($values)?, + SingleKind::Mode => MultipleValuesOperation::get_mode($values)?, + SingleKind::Std => MultipleValuesOperation::get_std($values)?, + SingleKind::Var => MultipleValuesOperation::get_var($values)?, + SingleKind::Count => MultipleValuesOperation::get_count($values), + SingleKind::Sum => MultipleValuesOperation::get_sum($values)?, + SingleKind::First => MultipleValuesOperation::get_first($values)?, + SingleKind::Last => MultipleValuesOperation::get_last($values)?, + } + }; +} + +macro_rules! get_single_value_comparison_operand_value { + ($operand:ident, $medrecord:ident) => { + match $operand { + SingleValueComparisonOperand::Operand(operand) => { + let context = &operand.context.context; + let attribute = operand.context.attribute.clone(); + let kind = &operand.kind; + + let comparison_values = context + .get_values($medrecord, attribute)? + .map(|value| (&0, value)); + + let comparison_value = get_single_operand_value!(kind, comparison_values); + + comparison_value + } + SingleValueComparisonOperand::Value(value) => value.clone(), + } + }; +} + +macro_rules! 
get_median { + ($values:ident, $variant:ident) => { + if $values.len() % 2 == 0 { + let middle = $values.len() / 2; + + let first = $values.get(middle - 1).unwrap(); + let second = $values.get(middle).unwrap(); + + let first = MedRecordValue::$variant(*first); + let second = MedRecordValue::$variant(*second); + + first.add(second).unwrap().div(MedRecordValue::Int(2)) + } else { + let middle = $values.len() / 2; + + Ok(MedRecordValue::$variant( + $values.get(middle).unwrap().clone(), + )) + } + }; +} + +#[derive(Debug, Clone)] +pub enum MultipleValuesOperation { + ValueOperation { + operand: Wrapper, + }, + SingleValueComparisonOperation { + operand: SingleValueComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleValuesComparisonOperation { + operand: MultipleValuesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleValueComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + IsFloat, + IsBool, + IsDateTime, + IsNull, + + IsMax, + IsMin, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl DeepClone for MultipleValuesOperation { + fn deep_clone(&self) -> Self { + match self { + Self::ValueOperation { operand } => Self::ValueOperation { + operand: operand.deep_clone(), + }, + Self::SingleValueComparisonOperation { operand, kind } => { + Self::SingleValueComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleValuesComparisonOperation { operand, kind } => { + Self::MultipleValuesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => 
Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsFloat => Self::IsFloat, + Self::IsBool => Self::IsBool, + Self::IsDateTime => Self::IsDateTime, + Self::IsNull => Self::IsNull, + Self::IsMax => Self::IsMax, + Self::IsMin => Self::IsMin, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl MultipleValuesOperation { + pub(crate) fn evaluate<'a, T: Eq + Hash>( + &self, + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + ) -> MedRecordResult> { + match self { + Self::ValueOperation { operand } => { + Self::evaluate_value_operation(medrecord, values, operand) + } + Self::SingleValueComparisonOperation { operand, kind } => { + Self::evaluate_single_value_comparison_operation(medrecord, values, operand, kind) + } + Self::MultipleValuesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_values_comparison_operation( + medrecord, values, operand, kind, + ) + } + Self::BinaryArithmeticOpration { operand, kind } => Ok(Box::new( + Self::evaluate_binary_arithmetic_operation(medrecord, values, operand, kind)?, + )), + Self::UnaryArithmeticOperation { kind } => Ok(Box::new( + Self::evaluate_unary_arithmetic_operation(values, kind.clone()), + )), + Self::Slice(range) => Ok(Box::new(Self::evaluate_slice(values, range.clone()))), + Self::IsString => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::String(_)) + }))) + } + Self::IsInt => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::Int(_)) + }))) + } + Self::IsFloat => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::Float(_)) + }))) + } + Self::IsBool => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::Bool(_)) + }))) + } + Self::IsDateTime => { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::DateTime(_)) + }))) + } + Self::IsNull 
=> { + Ok(Box::new(values.filter(|(_, value)| { + matches!(value, MedRecordValue::Null) + }))) + } + Self::IsMax => { + let max_value = Self::get_max(values)?; + + Ok(Box::new(std::iter::once(max_value))) + } + Self::IsMin => { + let min_value = Self::get_min(values)?; + + Ok(Box::new(std::iter::once(min_value))) + } + Self::EitherOr { either, or } => { + Self::evaluate_either_or(medrecord, values, either, or) + } + } + } + + #[inline] + pub(crate) fn get_max<'a, T>( + mut values: impl Iterator, + ) -> MedRecordResult<(&'a T, MedRecordValue)> { + let max_value = values.next().ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))?; + + values.try_fold(max_value, |max_value, value| { + match value.1.partial_cmp(&max_value.1) { + Some(Ordering::Greater) => Ok(value), + None => { + let first_dtype = DataType::from(value.1); + let second_dtype = DataType::from(max_value.1); + + Err(MedRecordError::QueryError(format!( + "Cannot compare values of data types {} and {}. Consider narrowing down the values using .is_string(), .is_int(), .is_float(), .is_bool() or .is_datetime()", + first_dtype, second_dtype + ))) + } + _ => Ok(max_value), + } + }) + } + + #[inline] + pub(crate) fn get_min<'a, T>( + mut values: impl Iterator, + ) -> MedRecordResult<(&'a T, MedRecordValue)> { + let min_value = values.next().ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))?; + + values.try_fold(min_value, |min_value, value| { + match value.1.partial_cmp(&min_value.1) { + Some(Ordering::Less) => Ok(value), + None => { + let first_dtype = DataType::from(value.1); + let second_dtype = DataType::from(min_value.1); + + Err(MedRecordError::QueryError(format!( + "Cannot compare values of data types {} and {}. 
Consider narrowing down the values using .is_string(), .is_int(), .is_float(), .is_bool() or .is_datetime()", + first_dtype, second_dtype + ))) + } + _ => Ok(min_value), + } + }) + } + + #[inline] + pub(crate) fn get_mean<'a, T: 'a>( + mut values: impl Iterator, + ) -> MedRecordResult { + let first_value = values.next().ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))?; + + let (sum, count) = values.try_fold((first_value.1, 1), |(sum, count), (_, value)| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&value); + + match sum.add(value) { + Ok(sum) => Ok((sum, count + 1)), + Err(_) => Err(MedRecordError::QueryError(format!( + "Cannot add values of data types {} and {}. Consider narrowing down the values using .is_int(), .is_float() or .is_datetime()", + first_dtype, second_dtype + ))), + } + })?; + + sum.div(MedRecordValue::Int(count as i64)) + } + + // TODO: This is a temporary solution. It should be optimized. + #[inline] + pub(crate) fn get_median<'a, T: 'a>( + mut values: impl Iterator, + ) -> MedRecordResult { + let first_value = values.next().ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))?; + + let first_data_type = DataType::from(&first_value.1); + + match first_value.1 { + MedRecordValue::Int(value) => { + let mut values = values.map(|(_, value)| { + let data_type = DataType::from(&value); + + match value { + MedRecordValue::Int(value) => Ok(value as f64), + MedRecordValue::Float(value) => Ok(value), + _ => Err(MedRecordError::QueryError(format!( + "Cannot calculate median of mixed data types {} and {}. 
Consider narrowing down the values using .is_int(), .is_float() or .is_datetime()", + first_data_type, data_type + ))), + } + }).collect::>>()?; + values.push(value as f64); + values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); + + get_median!(values, Float) + } + MedRecordValue::Float(value) => { + let mut values = values.map(|(_, value)| { + let data_type = DataType::from(&value); + + match value { + MedRecordValue::Int(value) => Ok(value as f64), + MedRecordValue::Float(value) => Ok(value), + _ => Err(MedRecordError::QueryError(format!( + "Cannot calculate median of mixed data types {} and {}. Consider narrowing down the values using .is_int(), .is_float() or .is_datetime()", + first_data_type, data_type + ))), + } + }).collect::>>()?; + values.push(value); + values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); + + get_median!(values, Float) + } + MedRecordValue::DateTime(value) => { + let mut values = values.map(|(_, value)| { + let data_type = DataType::from(&value); + + match value { + MedRecordValue::DateTime(naive_date_time) => Ok(naive_date_time), + _ => Err(MedRecordError::QueryError(format!( + "Cannot calculate median of mixed data types {} and {}. Consider narrowing down the values using .is_int(), .is_float() or .is_datetime()", + first_data_type, data_type + ))), + } + }).collect::>>()?; + values.push(value); + values.sort(); + + get_median!(values, DateTime) + } + _ => Err(MedRecordError::QueryError(format!( + "Cannot calculate median of data type {}", + first_data_type + )))?, + } + } + + // TODO: This is a temporary solution. It should be optimized. + #[inline] + pub(crate) fn get_mode<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordResult { + let values = values.map(|(_, value)| value).collect::>(); + + let most_common_value = values + .first() + .ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))? 
+ .clone(); + let most_common_count = values + .iter() + .filter(|value| **value == most_common_value) + .count(); + + let (_, most_common_value) = values.clone().into_iter().fold( + (most_common_count, most_common_value), + |acc, value| { + let count = values.iter().filter(|v| **v == value).count(); + + if count > acc.0 { + (count, value) + } else { + acc + } + }, + ); + + Ok(most_common_value) + } + + #[inline] + // 👀 + pub(crate) fn get_std<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordResult { + let variance = Self::get_var(values)?; + + let MedRecordValue::Float(variance) = variance else { + unreachable!() + }; + + Ok(MedRecordValue::Float(variance.sqrt())) + } + + // TODO: This is a temporary solution. It should be optimized. + #[inline] + pub(crate) fn get_var<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordResult { + let values = values.collect::>(); + + let mean = Self::get_mean(values.clone().into_iter())?; + + let MedRecordValue::Float(mean) = mean else { + let data_type = DataType::from(mean); + + return Err(MedRecordError::QueryError( + format!("Cannot calculate variance of data type {}. Consider narrowing down the values using .is_int() or .is_float()", data_type), + )); + }; + + let values = values + .into_iter() + .map(|value| { + let data_type = DataType::from(&value.1); + + match value.1 { + MedRecordValue::Int(value) => Ok(value as f64), + MedRecordValue::Float(value) => Ok(value), + _ => Err(MedRecordError::QueryError( + format!("Cannot calculate variance of data type {}. 
Consider narrowing down the values using .is_int() or .is_float()", data_type), + )), + }}) + .collect::>>()?; + + let values_length = values.len(); + + let variance = values + .into_iter() + .map(|value| (value - mean).powi(2)) + .sum::() + / values_length as f64; + + Ok(MedRecordValue::Float(variance)) + } + + #[inline] + pub(crate) fn get_count<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordValue { + MedRecordValue::Int(values.count() as i64) + } + + #[inline] + // 🥊💥 + pub(crate) fn get_sum<'a, T: 'a>( + mut values: impl Iterator, + ) -> MedRecordResult { + let first_value = values.next().ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + ))?; + + values.try_fold(first_value.1, |sum, (_, value)| { + let first_dtype = DataType::from(&sum); + let second_dtype = DataType::from(&value); + + sum.add(value).map_err(|_| { + MedRecordError::QueryError(format!( + "Cannot add values of data types {} and {}. Consider narrowing down the values using .is_string(), .is_int(), .is_float(), .is_bool() or .is_datetime()", + first_dtype, second_dtype + )) + }) + }) + } + + #[inline] + pub(crate) fn get_first<'a, T: 'a>( + mut values: impl Iterator, + ) -> MedRecordResult { + values + .next() + .ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + )) + .map(|(_, value)| value) + } + + #[inline] + // 🥊💥 + pub(crate) fn get_last<'a, T: 'a>( + values: impl Iterator, + ) -> MedRecordResult { + values + .last() + .ok_or(MedRecordError::QueryError( + "No values to compare".to_string(), + )) + .map(|(_, value)| value) + } + + #[inline] + fn evaluate_value_operation<'a, T>( + medrecord: &'a MedRecord, + values: impl Iterator, + operand: &Wrapper, + ) -> MedRecordResult> { + let kind = &operand.0.read_or_panic().kind; + + let values = values.collect::>(); + + let value = get_single_operand_value!(kind, values.clone().into_iter()); + + Ok(match operand.evaluate(medrecord, value)? 
{ + Some(_) => Box::new(values.into_iter()), + None => Box::new(std::iter::empty()), + }) + } + + #[inline] + fn evaluate_single_value_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + comparison_operand: &SingleValueComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_value = + get_single_value_comparison_operand_value!(comparison_operand, medrecord); + + match kind { + SingleComparisonKind::GreaterThan => Ok(Box::new( + values.filter(move |(_, value)| value > &comparison_value), + )), + SingleComparisonKind::GreaterThanOrEqualTo => Ok(Box::new( + values.filter(move |(_, value)| value >= &comparison_value), + )), + SingleComparisonKind::LessThan => Ok(Box::new( + values.filter(move |(_, value)| value < &comparison_value), + )), + SingleComparisonKind::LessThanOrEqualTo => Ok(Box::new( + values.filter(move |(_, value)| value <= &comparison_value), + )), + SingleComparisonKind::EqualTo => Ok(Box::new( + values.filter(move |(_, value)| value == &comparison_value), + )), + SingleComparisonKind::NotEqualTo => Ok(Box::new( + values.filter(move |(_, value)| value != &comparison_value), + )), + SingleComparisonKind::StartsWith => { + Ok(Box::new(values.filter(move |(_, value)| { + value.starts_with(&comparison_value) + }))) + } + SingleComparisonKind::EndsWith => { + Ok(Box::new(values.filter(move |(_, value)| { + value.ends_with(&comparison_value) + }))) + } + SingleComparisonKind::Contains => { + Ok(Box::new(values.filter(move |(_, value)| { + value.contains(&comparison_value) + }))) + } + } + } + + #[inline] + fn evaluate_multiple_values_comparison_operation<'a, T>( + medrecord: &'a MedRecord, + values: impl Iterator + 'a, + comparison_operand: &MultipleValuesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_values = match comparison_operand { + MultipleValuesComparisonOperand::Operand(operand) => { + let context = &operand.context; + let 
attribute = operand.attribute.clone(); + + context + .get_values(medrecord, attribute)? + .collect::>() + } + MultipleValuesComparisonOperand::Values(values) => values.clone(), + }; + + match kind { + MultipleComparisonKind::IsIn => { + Ok(Box::new(values.filter(move |(_, value)| { + comparison_values.contains(value) + }))) + } + MultipleComparisonKind::IsNotIn => { + Ok(Box::new(values.filter(move |(_, value)| { + !comparison_values.contains(value) + }))) + } + } + } + + #[inline] + fn evaluate_binary_arithmetic_operation<'a, T: 'a>( + medrecord: &'a MedRecord, + values: impl Iterator, + operand: &SingleValueComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_value = get_single_value_comparison_operand_value!(operand, medrecord); + + let values = values + .map(move |(t, value)| { + match kind { + BinaryArithmeticKind::Add => value.add(arithmetic_value.clone()), + BinaryArithmeticKind::Sub => value.sub(arithmetic_value.clone()), + BinaryArithmeticKind::Mul => { + value.clone().mul(arithmetic_value.clone()) + } + BinaryArithmeticKind::Div => { + value.clone().div(arithmetic_value.clone()) + } + BinaryArithmeticKind::Pow => { + value.clone().pow(arithmetic_value.clone()) + } + BinaryArithmeticKind::Mod => { + value.clone().r#mod(arithmetic_value.clone()) + } + } + .map_err(|_| { + MedRecordError::QueryError(format!( + "Failed arithmetic operation {}. Consider narrowing down the values using .is_int() or .is_float()", + kind, + )) + }).map(|result| (t, result)) + }); + + // TODO: This is a temporary solution. It should be optimized. 
+ Ok(values.collect::>>()?.into_iter()) + } + + #[inline] + fn evaluate_unary_arithmetic_operation<'a, T: 'a>( + values: impl Iterator, + kind: UnaryArithmeticKind, + ) -> impl Iterator { + values.map(move |(t, value)| { + let value = match kind { + UnaryArithmeticKind::Round => value.round(), + UnaryArithmeticKind::Ceil => value.ceil(), + UnaryArithmeticKind::Floor => value.floor(), + UnaryArithmeticKind::Abs => value.abs(), + UnaryArithmeticKind::Sqrt => value.sqrt(), + UnaryArithmeticKind::Trim => value.trim(), + UnaryArithmeticKind::TrimStart => value.trim_start(), + UnaryArithmeticKind::TrimEnd => value.trim_end(), + UnaryArithmeticKind::Lowercase => value.lowercase(), + UnaryArithmeticKind::Uppercase => value.uppercase(), + }; + (t, value) + }) + } + + #[inline] + fn evaluate_slice<'a, T: 'a>( + values: impl Iterator, + range: Range, + ) -> impl Iterator { + values.map(move |(t, value)| (t, value.slice(range.clone()))) + } + + #[inline] + fn evaluate_either_or<'a, T: 'a + Eq + Hash>( + medrecord: &'a MedRecord, + values: impl Iterator, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let values = values.collect::>(); + + let either_values = either.evaluate(medrecord, values.clone().into_iter())?; + let or_values = or.evaluate(medrecord, values.into_iter())?; + + Ok(Box::new( + either_values.chain(or_values).unique_by(|value| value.0), + )) + } +} + +#[derive(Debug, Clone)] +pub enum SingleValueOperation { + SingleValueComparisonOperation { + operand: SingleValueComparisonOperand, + kind: SingleComparisonKind, + }, + MultipleValuesComparisonOperation { + operand: MultipleValuesComparisonOperand, + kind: MultipleComparisonKind, + }, + BinaryArithmeticOpration { + operand: SingleValueComparisonOperand, + kind: BinaryArithmeticKind, + }, + UnaryArithmeticOperation { + kind: UnaryArithmeticKind, + }, + + Slice(Range), + + IsString, + IsInt, + IsFloat, + IsBool, + IsDateTime, + IsNull, + + EitherOr { + either: Wrapper, + or: Wrapper, + }, +} + +impl 
DeepClone for SingleValueOperation { + fn deep_clone(&self) -> Self { + match self { + Self::SingleValueComparisonOperation { operand, kind } => { + Self::SingleValueComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::MultipleValuesComparisonOperation { operand, kind } => { + Self::MultipleValuesComparisonOperation { + operand: operand.deep_clone(), + kind: kind.clone(), + } + } + Self::BinaryArithmeticOpration { operand, kind } => Self::BinaryArithmeticOpration { + operand: operand.deep_clone(), + kind: kind.clone(), + }, + Self::UnaryArithmeticOperation { kind } => { + Self::UnaryArithmeticOperation { kind: kind.clone() } + } + Self::Slice(range) => Self::Slice(range.clone()), + Self::IsString => Self::IsString, + Self::IsInt => Self::IsInt, + Self::IsFloat => Self::IsFloat, + Self::IsBool => Self::IsBool, + Self::IsDateTime => Self::IsDateTime, + Self::IsNull => Self::IsNull, + Self::EitherOr { either, or } => Self::EitherOr { + either: either.deep_clone(), + or: or.deep_clone(), + }, + } + } +} + +impl SingleValueOperation { + pub(crate) fn evaluate( + &self, + medrecord: &MedRecord, + value: MedRecordValue, + ) -> MedRecordResult> { + match self { + Self::SingleValueComparisonOperation { operand, kind } => { + Self::evaluate_single_value_comparison_operation(medrecord, value, operand, kind) + } + Self::MultipleValuesComparisonOperation { operand, kind } => { + Self::evaluate_multiple_values_comparison_operation(medrecord, value, operand, kind) + } + Self::BinaryArithmeticOpration { operand, kind } => { + Self::evaluate_binary_arithmetic_operation(medrecord, value, operand, kind) + } + Self::UnaryArithmeticOperation { kind } => Ok(Some(match kind { + UnaryArithmeticKind::Round => value.round(), + UnaryArithmeticKind::Ceil => value.ceil(), + UnaryArithmeticKind::Floor => value.floor(), + UnaryArithmeticKind::Abs => value.abs(), + UnaryArithmeticKind::Sqrt => value.sqrt(), + UnaryArithmeticKind::Trim => value.trim(), + 
UnaryArithmeticKind::TrimStart => value.trim_start(), + UnaryArithmeticKind::TrimEnd => value.trim_end(), + UnaryArithmeticKind::Lowercase => value.lowercase(), + UnaryArithmeticKind::Uppercase => value.uppercase(), + })), + Self::Slice(range) => Ok(Some(value.slice(range.clone()))), + Self::IsString => Ok(match value { + MedRecordValue::String(_) => Some(value), + _ => None, + }), + Self::IsInt => Ok(match value { + MedRecordValue::Int(_) => Some(value), + _ => None, + }), + Self::IsFloat => Ok(match value { + MedRecordValue::Float(_) => Some(value), + _ => None, + }), + Self::IsBool => Ok(match value { + MedRecordValue::Bool(_) => Some(value), + _ => None, + }), + Self::IsDateTime => Ok(match value { + MedRecordValue::DateTime(_) => Some(value), + _ => None, + }), + Self::IsNull => Ok(match value { + MedRecordValue::Null => Some(value), + _ => None, + }), + Self::EitherOr { either, or } => Self::evaluate_either_or(medrecord, value, either, or), + } + } + + #[inline] + fn evaluate_single_value_comparison_operation( + medrecord: &MedRecord, + value: MedRecordValue, + comparison_operand: &SingleValueComparisonOperand, + kind: &SingleComparisonKind, + ) -> MedRecordResult> { + let comparison_value = + get_single_value_comparison_operand_value!(comparison_operand, medrecord); + + let comparison_result = match kind { + SingleComparisonKind::GreaterThan => value > comparison_value, + SingleComparisonKind::GreaterThanOrEqualTo => value >= comparison_value, + SingleComparisonKind::LessThan => value < comparison_value, + SingleComparisonKind::LessThanOrEqualTo => value <= comparison_value, + SingleComparisonKind::EqualTo => value == comparison_value, + SingleComparisonKind::NotEqualTo => value != comparison_value, + SingleComparisonKind::StartsWith => value.starts_with(&comparison_value), + SingleComparisonKind::EndsWith => value.ends_with(&comparison_value), + SingleComparisonKind::Contains => value.contains(&comparison_value), + }; + + Ok(if comparison_result { 
Some(value) } else { None }) + } + + #[inline] + fn evaluate_multiple_values_comparison_operation( + medrecord: &MedRecord, + value: MedRecordValue, + comparison_operand: &MultipleValuesComparisonOperand, + kind: &MultipleComparisonKind, + ) -> MedRecordResult> { + let comparison_values = match comparison_operand { + MultipleValuesComparisonOperand::Operand(operand) => { + let context = &operand.context; + let attribute = operand.attribute.clone(); + + context + .get_values(medrecord, attribute)? + .collect::>() + } + MultipleValuesComparisonOperand::Values(values) => values.clone(), + }; + + let comparison_result = match kind { + MultipleComparisonKind::IsIn => comparison_values.contains(&value), + MultipleComparisonKind::IsNotIn => !comparison_values.contains(&value), + }; + + Ok(if comparison_result { Some(value) } else { None }) + } + + #[inline] + fn evaluate_binary_arithmetic_operation( + medrecord: &MedRecord, + value: MedRecordValue, + operand: &SingleValueComparisonOperand, + kind: &BinaryArithmeticKind, + ) -> MedRecordResult> { + let arithmetic_value = get_single_value_comparison_operand_value!(operand, medrecord); + + match kind { + BinaryArithmeticKind::Add => value.add(arithmetic_value), + BinaryArithmeticKind::Sub => value.sub(arithmetic_value), + BinaryArithmeticKind::Mul => value.mul(arithmetic_value), + BinaryArithmeticKind::Div => value.div(arithmetic_value), + BinaryArithmeticKind::Pow => value.pow(arithmetic_value), + BinaryArithmeticKind::Mod => value.r#mod(arithmetic_value), + } + .map(Some) + } + + #[inline] + fn evaluate_either_or( + medrecord: &MedRecord, + value: MedRecordValue, + either: &Wrapper, + or: &Wrapper, + ) -> MedRecordResult> { + let either_result = either.evaluate(medrecord, value.clone())?; + let or_result = or.evaluate(medrecord, value)?; + + match (either_result, or_result) { + (Some(either_result), _) => Ok(Some(either_result)), + (None, Some(or_result)) => Ok(Some(or_result)), + _ => Ok(None), + } + } +} diff --git 
a/crates/medmodels-core/src/medrecord/querying/wrapper.rs b/crates/medmodels-core/src/medrecord/querying/wrapper.rs new file mode 100644 index 00000000..a5d338bc --- /dev/null +++ b/crates/medmodels-core/src/medrecord/querying/wrapper.rs @@ -0,0 +1,45 @@ +use super::traits::{DeepClone, ReadWriteOrPanic}; +use std::sync::{Arc, RwLock}; + +#[repr(transparent)] +#[derive(Debug, Clone)] +pub struct Wrapper(pub(crate) Arc>); + +impl From for Wrapper { + fn from(value: T) -> Self { + Self(Arc::new(RwLock::new(value))) + } +} + +impl DeepClone for Wrapper +where + T: DeepClone, +{ + fn deep_clone(&self) -> Self { + self.0.read_or_panic().deep_clone().into() + } +} + +#[derive(Debug, Clone)] +pub enum CardinalityWrapper { + Single(T), + Multiple(Vec), +} + +impl From for CardinalityWrapper { + fn from(value: T) -> Self { + Self::Single(value) + } +} + +impl From> for CardinalityWrapper { + fn from(value: Vec) -> Self { + Self::Multiple(value) + } +} + +impl From<[T; N]> for CardinalityWrapper { + fn from(value: [T; N]) -> Self { + Self::Multiple(value.to_vec()) + } +} diff --git a/crates/medmodels-core/src/medrecord/schema.rs b/crates/medmodels-core/src/medrecord/schema.rs index 2bcdd562..8015870e 100644 --- a/crates/medmodels-core/src/medrecord/schema.rs +++ b/crates/medmodels-core/src/medrecord/schema.rs @@ -1,5 +1,3 @@ -#![allow(dead_code)] - use super::{Attributes, EdgeIndex, NodeIndex}; use crate::{ errors::GraphError, diff --git a/rustmodels/Cargo.toml b/rustmodels/Cargo.toml index a6640f90..70922829 100644 --- a/rustmodels/Cargo.toml +++ b/rustmodels/Cargo.toml @@ -11,7 +11,7 @@ crate-type = ["cdylib"] medmodels-core = { workspace = true } medmodels-utils = { workspace = true } -pyo3 = { workspace = true } -pyo3-polars = { workspace = true } +pyo3 = { version = "0.21.2", features = ["chrono"] } +pyo3-polars = "0.14.0" polars = { workspace = true } chrono = { workspace = true } diff --git a/rustmodels/src/medrecord/mod.rs b/rustmodels/src/medrecord/mod.rs index 
eff9caf9..baa4f2f6 100644 --- a/rustmodels/src/medrecord/mod.rs +++ b/rustmodels/src/medrecord/mod.rs @@ -637,7 +637,7 @@ impl PyMedRecord { .map(|node_index| { let neighbors = self .0 - .neighbors(&node_index) + .neighbors_outgoing(&node_index) .map_err(PyMedRecordError::from)? .map(|neighbor| neighbor.clone().into()) .collect();