From eaecee09b5b0a9f90c191534ae4c2db2f456eb9d Mon Sep 17 00:00:00 2001 From: Eddy Oyieko <67474838+mobley-trent@users.noreply.github.com> Date: Thu, 15 Feb 2024 16:29:06 +0300 Subject: [PATCH 1/9] Initial commit --- datafusion/core/src/execution/context/mod.rs | 30 ++++++++++++++++++++ datafusion/execution/src/registry.rs | 27 ++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 0bc75720e7c8..724a5f146864 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -848,6 +848,18 @@ impl SessionContext { self.state.write().register_udwf(Arc::new(f)).ok(); } + pub fn deregister_udf(&self, name: &str) { + self.state.write().deregister_udf(name); + } + + pub fn deregister_udaf(&self, name: &str) { + self.state.write().deregister_udaf(name); + } + + pub fn deregister_udwf(&self, name: &str) { + self.state.write().deregister_udwf(name); + } + /// Creates a [`DataFrame`] for reading a data source. /// /// For more control such as reading multiple files, you can use @@ -1986,6 +1998,24 @@ impl FunctionRegistry for SessionState { fn register_udwf(&mut self, udwf: Arc) -> Result>> { Ok(self.window_functions.insert(udwf.name().into(), udwf)) } + + fn deregister_udf(&mut self, name: &str) -> Result>> { + let udf = self.scalar_functions.remove(name); + if let Some(udf) = &udf { + for alias in udf.aliases() { + self.scalar_functions.remove(alias); + } + } + Ok(udf) + } + + fn deregister_udaf(&mut self, name: &str) -> Result>> { + Ok(self.aggregate_functions.remove(name)) + } + + fn deregister_udwf(&mut self, name: &str) -> Result>> { + Ok(self.window_functions.remove(name)) + } } impl OptimizerConfig for SessionState { diff --git a/datafusion/execution/src/registry.rs b/datafusion/execution/src/registry.rs index a4bb8d1cc649..b5c999e36f63 100644 --- a/datafusion/execution/src/registry.rs +++ b/datafusion/execution/src/registry.rs @@ -68,6 +68,33 @@ pub trait FunctionRegistry { fn register_udwf(&mut self, _udaf: Arc) -> Result>> { not_impl_err!("Registering WindowUDF") } + + /// Deregisters a [`ScalarUDF`], returning the implementation that was + /// deregistered. + /// + /// Returns an error (the default) if the function can not be deregistered, + /// for example if the registry is read only. + fn deregister_udf(&mut self, _name: &str) -> Result>> { + not_impl_err!("Deregistering ScalarUDF") + } + + /// Deregisters a [`AggregateUDF`], returning the implementation that was + /// deregistered. + /// + /// Returns an error (the default) if the function can not be deregistered, + /// for example if the registry is read only. + fn deregister_udaf(&mut self, _name: &str) -> Result>> { + not_impl_err!("Deregistering AggregateUDF") + } + + /// Deregisters a [`WindowUDF`], returning the implementation that was + /// deregistered. + /// + /// Returns an error (the default) if the function can not be deregistered, + /// for example if the registry is read only. + fn deregister_udwf(&mut self, _name: &str) -> Result>> { + not_impl_err!("Deregistering WindowUDF") + } } /// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]. From 8631337a7b5df51bb89b9ecb6bf982f119562ec1 Mon Sep 17 00:00:00 2001 From: Eddy Oyieko <67474838+mobley-trent@users.noreply.github.com> Date: Mon, 19 Feb 2024 20:16:32 +0300 Subject: [PATCH 2/9] Updated mod.rs - Docstrings, Initial test --- datafusion/core/src/execution/context/mod.rs | 62 +++++++++++++++++++- 1 file changed, 59 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 724a5f146864..5d79092b4c50 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -848,16 +848,19 @@ impl SessionContext { self.state.write().register_udwf(Arc::new(f)).ok(); } + /// Deregisters a UDF within this context. pub fn deregister_udf(&self, name: &str) { - self.state.write().deregister_udf(name); + self.state.write().deregister_udf(name).ok(); } + /// Deregisters a UDAF within this context. pub fn deregister_udaf(&self, name: &str) { - self.state.write().deregister_udaf(name); + self.state.write().deregister_udaf(name).ok(); } + /// Deregisters a UDTF within this context. pub fn deregister_udwf(&self, name: &str) { - self.state.write().deregister_udwf(name); + self.state.write().deregister_udwf(name).ok(); } /// Creates a [`DataFrame`] for reading a data source. @@ -2184,6 +2187,16 @@ mod tests { use std::path::PathBuf; use std::sync::Weak; use tempfile::TempDir; + use datafusion_expr::ColumnarValue; + use datafusion_expr::expr_fn::create_udf; + use datafusion_common::cast::as_float64_array; + use crate::{ + arrow::{ + array::{ArrayRef, Float64Array}, + datatypes::DataType, + }, + logical_expr::Volatility, + }; #[tokio::test] async fn shared_memory_and_disk_manager() { @@ -2277,6 +2290,49 @@ mod tests { Ok(()) } + #[tokio::test] + async fn register_deregister_udf() -> Result<()> { + let pow = Arc::new(|args: &[ColumnarValue]| { + assert_eq!(args.len(), 2); + + let args = ColumnarValue::values_to_arrays(args)?; + + let base = as_float64_array(&args[0]).expect("cast failed"); + let exponent = as_float64_array(&args[1]).expect("cast failed"); + + assert_eq!(exponent.len(), base.len()); + + let array = base + .iter() + .zip(exponent.iter()) + .map(|(base, exponent)| { + match (base, exponent) { + (Some(base), Some(exponent)) => Some(base.powf(exponent)), + _ => None, + } + }) + .collect::(); + + Ok(ColumnarValue::from(Arc::new(array) as ArrayRef)) + }); + + let pow = create_udf( + "pow", + vec![DataType::Float64, DataType::Float64], + Arc::new(DataType::Float64), + Volatility::Immutable, + pow, + ); + + let ctx = SessionContext::new(); + ctx.register_udf(pow.clone()); + + assert!(ctx.deregister_udf("pow").is_some()); + assert!(ctx.deregister_udf("pow").is_none()); + + Ok(()) + } + #[tokio::test] async fn send_context_to_threads() -> Result<()> { // ensure SessionContexts can be used in a multi-threaded From ff655671183bacd0ade8efa2e3fee9bcb927faac Mon Sep 17 00:00:00 2001 From: Eddy Oyieko <67474838+mobley-trent@users.noreply.github.com> Date: Mon, 19 Feb 2024 22:10:26 +0300 Subject: [PATCH 3/9] Updated mod.rs - Fixed udf test --- datafusion/core/src/execution/context/mod.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 5d79092b4c50..8e693ee52d66 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -2327,8 +2327,11 @@ mod tests { let ctx = SessionContext::new(); ctx.register_udf(pow.clone()); - assert!(ctx.deregister_udf("pow").is_some()); - assert!(ctx.deregister_udf("pow").is_none()); + assert!(ctx.udfs().contains("pow")); + + ctx.deregister_udf("pow"); + + assert!(!ctx.udfs().contains("pow")); Ok(()) } From 2940a90b8cc940c01e6499da65be45db1db44738 Mon Sep 17 00:00:00 2001 From: Eddy Oyieko <67474838+mobley-trent@users.noreply.github.com> Date: Fri, 23 Feb 2024 08:41:39 +0300 Subject: [PATCH 4/9] Added udaf test, Updated udf test --- datafusion/core/src/execution/context/mod.rs | 140 +++++++++++++++---- 1 file changed, 112 insertions(+), 28 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 8e693ee52d66..6feb3be9fe7a 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -2188,14 +2188,16 @@ mod tests { use std::sync::Weak; use tempfile::TempDir; use datafusion_expr::ColumnarValue; - use datafusion_expr::expr_fn::create_udf; - use datafusion_common::cast::as_float64_array; + use datafusion_expr::expr_fn::{create_udf, create_udaf}; + use datafusion_common::cast::as_int64_array; use crate::{ arrow::{ - array::{ArrayRef, Float64Array}, + array::{ArrayRef, Int64Array}, datatypes::DataType, }, logical_expr::Volatility, + physical_plan::Accumulator, + scalar::ScalarValue, }; #[tokio::test] @@ -2292,46 +2294,128 @@ mod tests { #[tokio::test] async fn register_deregister_udf() -> Result<()> { - let pow = Arc::new(|args: &[ColumnarValue]| { - assert_eq!(args.len(), 2); + let add = Arc::new(|args: &[ColumnarValue]| { let args = ColumnarValue::values_to_arrays(args)?; + let i64s = as_int64_array(&args[0])?; - let base = as_float64_array(&args[0]).expect("cast failed"); - let exponent = as_float64_array(&args[1]).expect("cast failed"); - - assert_eq!(exponent.len(), base.len()); - - let array = base + let array = i64s .iter() - .zip(exponent.iter()) - .map(|(base, exponent)| { - match (base, exponent) { - (Some(base), Some(exponent)) => Some(base.powf(exponent)), - _ => None, - } - }) - .collect::(); - + .map(|array_elem| array_elem.map(|value| value + 1)) + .collect::(); + Ok(ColumnarValue::from(Arc::new(array) as ArrayRef)) }); - let pow = create_udf( - "pow", - vec![DataType::Float64, DataType::Float64], + let udf = create_udf( + "add", + vec![DataType::Int64], + Arc::new(DataType::Int64), + Volatility::Immutable, + add, + ); + + let ctx = SessionContext::new(); + ctx.register_udf(udf.clone()); + + assert!(ctx.udfs().contains("add")); + + ctx.deregister_udf("add"); + + assert!(!ctx.udfs().contains("add")); + + Ok(()) + } + + #[tokio::test] + async fn register_deregister_udaf() -> Result<()> { + #[derive(Debug)] + struct GeometricMean { + n: u32, + prod: f64, + } + + impl GeometricMean { + pub fn new() -> Self { + GeometricMean { n: 0, prod: 1.0 } + } + } + + impl Accumulator for GeometricMean { + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.prod), + ScalarValue::from(self.n), + ]) + } + + fn evaluate(&mut self) -> Result { + let value = self.prod.powf(1.0 / self.n as f64); + Ok(ScalarValue::from(value)) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + let arr = &values[0]; + (0..arr.len()).try_for_each(|index| { + let v = ScalarValue::try_from_array(arr, index)?; + + if let ScalarValue::Float64(Some(value)) = v { + self.prod *= value; + self.n += 1; + } else { + unreachable!("") + } + Ok(()) + }) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + let arr = &states[0]; + (0..arr.len()).try_for_each(|index| { + let v = states + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::>>()?; + if let (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) = + (&v[0], &v[1]) + { + self.prod *= prod; + self.n += n; + } else { + unreachable!("") + } + Ok(()) + }) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + } + + let udaf = create_udaf( + "geo_mean", + vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - pow, + Arc::new(|_| Ok(Box::new(GeometricMean::new()))), + Arc::new(vec![DataType::Float64, DataType::UInt32]), ); let ctx = SessionContext::new(); - ctx.register_udf(pow.clone()); + ctx.register_udaf(udaf.clone()); - assert!(ctx.udfs().contains("pow")); + assert!(ctx.state().aggregate_functions.contains_key("geo_mean")); - ctx.deregister_udf("pow"); + ctx.deregister_udaf("geo_mean"); - assert!(!ctx.udfs().contains("pow")); + assert!(!ctx.state().aggregate_functions.contains_key("geo_mean")); Ok(()) } From e254313cd74842ed74a02d8f5d90a1e7972c41aa Mon Sep 17 00:00:00 2001 From: Eddy Oyieko <67474838+mobley-trent@users.noreply.github.com> Date: Tue, 27 Feb 2024 10:45:18 +0300 Subject: [PATCH 5/9] Added test for udwf --- datafusion/core/src/execution/context/mod.rs | 68 ++++++++++++++++++-- 1 file changed, 64 insertions(+), 4 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 6feb3be9fe7a..2000364954b3 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -2187,13 +2187,13 @@ mod tests { use std::path::PathBuf; use std::sync::Weak; use tempfile::TempDir; - use datafusion_expr::ColumnarValue; - use datafusion_expr::expr_fn::{create_udf, create_udaf}; + use datafusion_expr::{ColumnarValue, PartitionEvaluator}; + use datafusion_expr::expr_fn::{create_udf, create_udaf, create_udwf}; use datafusion_common::cast::as_int64_array; use crate::{ arrow::{ - array::{ArrayRef, Int64Array}, - datatypes::DataType, + array::{ArrayRef, AsArray, Int64Array, Float64Array}, + datatypes::{DataType, Float64Type}, }, logical_expr::Volatility, physical_plan::Accumulator, @@ -2420,6 +2420,66 @@ mod tests { Ok(()) } + #[tokio::test] + async fn register_deregister_udwf() -> Result<()> { + #[derive(Clone, Debug)] + struct MyPartitionEvaluator {} + + impl MyPartitionEvaluator { + fn new() -> Self { + Self {} + } + } + + impl PartitionEvaluator for MyPartitionEvaluator { + fn uses_window_frame(&self) -> bool { + true + } + + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &std::ops::Range, + ) -> Result { + let arr: &Float64Array = values[0].as_ref().as_primitive::(); + let range_len = range.end - range.start; + + let output = if range_len > 0 { + let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum(); + Some(sum / range_len as f64) + } else { + None + }; + + Ok(ScalarValue::Float64(output)) + } + } + + fn make_partition_evaluator() -> Result> { + Ok(Box::new(MyPartitionEvaluator::new())) + } + + let smooth_it = create_udwf( + "smooth_it", + DataType::Float64, + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(make_partition_evaluator), + ); + + let ctx = SessionContext::new(); + + ctx.register_udwf(smooth_it.clone()); + + assert!(ctx.state().window_functions.contains_key("smooth_it")); + + ctx.deregister_udwf("smooth_it"); + + assert!(!ctx.state().window_functions.contains_key("smooth_it")); + + Ok(()) + } + #[tokio::test] async fn send_context_to_threads() -> Result<()> { // ensure SessionContexts can be used in a multi-threaded From 4cd5e472504345e954936f53d9076db896b3e4d3 Mon Sep 17 00:00:00 2001 From: Eddy Oyieko <67474838+mobley-trent@users.noreply.github.com> Date: Tue, 27 Feb 2024 11:07:12 +0300 Subject: [PATCH 6/9] Linting with rustfmt --- datafusion/core/src/execution/context/mod.rs | 52 ++++++++++---------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 2000364954b3..0fc1fd1b3db5 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -2181,24 +2181,24 @@ mod tests { use crate::test; use crate::test_util::{plan_and_collect, populate_csv_partitions}; use crate::variable::VarType; - use async_trait::async_trait; - use datafusion_expr::Expr; - use std::env; - use std::path::PathBuf; - use std::sync::Weak; - use tempfile::TempDir; - use datafusion_expr::{ColumnarValue, PartitionEvaluator}; - use datafusion_expr::expr_fn::{create_udf, create_udaf, create_udwf}; - use datafusion_common::cast::as_int64_array; use crate::{ arrow::{ - array::{ArrayRef, AsArray, Int64Array, Float64Array}, + array::{ArrayRef, AsArray, Float64Array, Int64Array}, datatypes::{DataType, Float64Type}, }, logical_expr::Volatility, physical_plan::Accumulator, scalar::ScalarValue, }; + use async_trait::async_trait; + use datafusion_common::cast::as_int64_array; + use datafusion_expr::expr_fn::{create_udaf, create_udf, create_udwf}; + use datafusion_expr::Expr; + use datafusion_expr::{ColumnarValue, PartitionEvaluator}; + use std::env; + use std::path::PathBuf; + use std::sync::Weak; + use tempfile::TempDir; #[tokio::test] async fn shared_memory_and_disk_manager() { @@ -2295,10 +2295,9 @@ mod tests { #[tokio::test] async fn register_deregister_udf() -> Result<()> { let add = Arc::new(|args: &[ColumnarValue]| { - let args = ColumnarValue::values_to_arrays(args)?; let i64s = as_int64_array(&args[0])?; - + let array = i64s .iter() .map(|array_elem| array_elem.map(|value| value + 1)) @@ -2334,13 +2333,13 @@ mod tests { n: u32, prod: f64, } - + impl GeometricMean { pub fn new() -> Self { GeometricMean { n: 0, prod: 1.0 } } } - + impl Accumulator for GeometricMean { fn state(&mut self) -> Result> { Ok(vec![ @@ -2348,12 +2347,12 @@ mod tests { ScalarValue::from(self.n), ]) } - + fn evaluate(&mut self) -> Result { let value = self.prod.powf(1.0 / self.n as f64); Ok(ScalarValue::from(value)) } - + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { if values.is_empty() { return Ok(()); @@ -2361,7 +2360,7 @@ mod tests { let arr = &values[0]; (0..arr.len()).try_for_each(|index| { let v = ScalarValue::try_from_array(arr, index)?; - + if let ScalarValue::Float64(Some(value)) = v { self.prod *= value; self.n += 1; @@ -2371,7 +2370,7 @@ mod tests { Ok(()) }) } - + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { if states.is_empty() { return Ok(()); @@ -2382,8 +2381,10 @@ mod tests { .iter() .map(|array| ScalarValue::try_from_array(array, index)) .collect::>>()?; - if let (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) = - (&v[0], &v[1]) + if let ( + ScalarValue::Float64(Some(prod)), + ScalarValue::UInt32(Some(n)), + ) = (&v[0], &v[1]) { self.prod *= prod; self.n += n; @@ -2393,7 +2394,7 @@ mod tests { Ok(()) }) } - + fn size(&self) -> usize { std::mem::size_of_val(self) } @@ -2435,7 +2436,7 @@ mod tests { fn uses_window_frame(&self) -> bool { true } - + fn evaluate( &mut self, values: &[ArrayRef], @@ -2443,14 +2444,15 @@ mod tests { ) -> Result { let arr: &Float64Array = values[0].as_ref().as_primitive::(); let range_len = range.end - range.start; - + let output = if range_len > 0 { - let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum(); + let sum: f64 = + arr.values().iter().skip(range.start).take(range_len).sum(); Some(sum / range_len as f64) } else { None }; - + Ok(ScalarValue::Float64(output)) } } From 5758d1b026206aa995c66732a1d462fe5aa39dee Mon Sep 17 00:00:00 2001 From: Eddy Oyieko <67474838+mobley-trent@users.noreply.github.com> Date: Tue, 27 Feb 2024 17:23:50 +0300 Subject: [PATCH 7/9] Update datafusion/core/src/execution/context/mod.rs Co-authored-by: Andrew Lamb --- datafusion/core/src/execution/context/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 0fc1fd1b3db5..adee33716b8d 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -858,7 +858,7 @@ impl SessionContext { self.state.write().deregister_udaf(name).ok(); } - /// Deregisters a UDTF within this context. + /// Deregisters a UDWF within this context. pub fn deregister_udwf(&self, name: &str) { self.state.write().deregister_udwf(name).ok(); } From 4a4ea260eb38dbe69a5bf4def3931c2dd43bbcf1 Mon Sep 17 00:00:00 2001 From: Eddy Oyieko <67474838+mobley-trent@users.noreply.github.com> Date: Tue, 27 Feb 2024 18:21:05 +0300 Subject: [PATCH 8/9] Moved tests to core/tests/user_defined --- datafusion/core/src/execution/context/mod.rs | 202 ------------------ .../user_defined/user_defined_aggregates.rs | 23 ++ .../user_defined_scalar_functions.rs | 16 ++ .../user_defined_window_functions.rs | 15 ++ 4 files changed, 54 insertions(+), 202 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index adee33716b8d..8cc6b5a5cffc 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -2181,20 +2181,8 @@ mod tests { use crate::test; use crate::test_util::{plan_and_collect, populate_csv_partitions}; use crate::variable::VarType; - use crate::{ - arrow::{ - array::{ArrayRef, AsArray, Float64Array, Int64Array}, - datatypes::{DataType, Float64Type}, - }, - logical_expr::Volatility, - physical_plan::Accumulator, - scalar::ScalarValue, - }; use async_trait::async_trait; - use datafusion_common::cast::as_int64_array; - use datafusion_expr::expr_fn::{create_udaf, create_udf, create_udwf}; use datafusion_expr::Expr; - use datafusion_expr::{ColumnarValue, PartitionEvaluator}; use std::env; use std::path::PathBuf; use std::sync::Weak; @@ -2292,196 +2280,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn register_deregister_udf() -> Result<()> { - let add = Arc::new(|args: &[ColumnarValue]| { - let args = ColumnarValue::values_to_arrays(args)?; - let i64s = as_int64_array(&args[0])?; - - let array = i64s - .iter() - .map(|array_elem| array_elem.map(|value| value + 1)) - .collect::(); - - Ok(ColumnarValue::from(Arc::new(array) as ArrayRef)) - }); - - let udf = create_udf( - "add", - vec![DataType::Int64], - Arc::new(DataType::Int64), - Volatility::Immutable, - add, - ); - - let ctx = SessionContext::new(); - ctx.register_udf(udf.clone()); - - assert!(ctx.udfs().contains("add")); - - ctx.deregister_udf("add"); - - assert!(!ctx.udfs().contains("add")); - - Ok(()) - } - - #[tokio::test] - async fn register_deregister_udaf() -> Result<()> { - #[derive(Debug)] - struct GeometricMean { - n: u32, - prod: f64, - } - - impl GeometricMean { - pub fn new() -> Self { - GeometricMean { n: 0, prod: 1.0 } - } - } - - impl Accumulator for GeometricMean { - fn state(&mut self) -> Result> { - Ok(vec![ - ScalarValue::from(self.prod), - ScalarValue::from(self.n), - ]) - } - - fn evaluate(&mut self) -> Result { - let value = self.prod.powf(1.0 / self.n as f64); - Ok(ScalarValue::from(value)) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - let arr = &values[0]; - (0..arr.len()).try_for_each(|index| { - let v = ScalarValue::try_from_array(arr, index)?; - - if let ScalarValue::Float64(Some(value)) = v { - self.prod *= value; - self.n += 1; - } else { - unreachable!("") - } - Ok(()) - }) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - let arr = &states[0]; - (0..arr.len()).try_for_each(|index| { - let v = states - .iter() - .map(|array| ScalarValue::try_from_array(array, index)) - .collect::>>()?; - if let ( - ScalarValue::Float64(Some(prod)), - ScalarValue::UInt32(Some(n)), - ) = (&v[0], &v[1]) - { - self.prod *= prod; - self.n += n; - } else { - unreachable!("") - } - Ok(()) - }) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } - } - - let udaf = create_udaf( - "geo_mean", - vec![DataType::Float64], - Arc::new(DataType::Float64), - Volatility::Immutable, - Arc::new(|_| Ok(Box::new(GeometricMean::new()))), - Arc::new(vec![DataType::Float64, DataType::UInt32]), - ); - - let ctx = SessionContext::new(); - ctx.register_udaf(udaf.clone()); - - assert!(ctx.state().aggregate_functions.contains_key("geo_mean")); - - ctx.deregister_udaf("geo_mean"); - - assert!(!ctx.state().aggregate_functions.contains_key("geo_mean")); - - Ok(()) - } - - #[tokio::test] - async fn register_deregister_udwf() -> Result<()> { - #[derive(Clone, Debug)] - struct MyPartitionEvaluator {} - - impl MyPartitionEvaluator { - fn new() -> Self { - Self {} - } - } - - impl PartitionEvaluator for MyPartitionEvaluator { - fn uses_window_frame(&self) -> bool { - true - } - - fn evaluate( - &mut self, - values: &[ArrayRef], - range: &std::ops::Range, - ) -> Result { - let arr: &Float64Array = values[0].as_ref().as_primitive::(); - let range_len = range.end - range.start; - - let output = if range_len > 0 { - let sum: f64 = - arr.values().iter().skip(range.start).take(range_len).sum(); - Some(sum / range_len as f64) - } else { - None - }; - - Ok(ScalarValue::Float64(output)) - } - } - - fn make_partition_evaluator() -> Result> { - Ok(Box::new(MyPartitionEvaluator::new())) - } - - let smooth_it = create_udwf( - "smooth_it", - DataType::Float64, - Arc::new(DataType::Float64), - Volatility::Immutable, - Arc::new(make_partition_evaluator), - ); - - let ctx = SessionContext::new(); - - ctx.register_udwf(smooth_it.clone()); - - assert!(ctx.state().window_functions.contains_key("smooth_it")); - - ctx.deregister_udwf("smooth_it"); - - assert!(!ctx.state().window_functions.contains_key("smooth_it")); - - Ok(()) - } - #[tokio::test] async fn send_context_to_threads() -> Result<()> { // ensure SessionContexts can be used in a multi-threaded diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 0b29ad10d670..8daeefd236f7 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -255,6 +255,29 @@ async fn simple_udaf() -> Result<()> { Ok(()) } +#[tokio::test] +async fn deregister_udaf() -> Result<()> { + let ctx = SessionContext::new(); + let my_avg = create_udaf( + "my_avg", + vec![DataType::Float64], + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(|_| Ok(Box::::default())), + Arc::new(vec![DataType::UInt64, DataType::Float64]), + ); + + ctx.register_udaf(my_avg.clone()); + + assert!(ctx.state().aggregate_functions().contains_key("my_avg")); + + ctx.deregister_udaf("my_avg"); + + assert!(!ctx.state().aggregate_functions().contains_key("my_avg")); + + Ok(()) +} + #[tokio::test] async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index a86c76b9b6dd..37fa297807c3 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -494,6 +494,22 @@ async fn test_user_defined_functions_zero_argument() -> Result<()> { Ok(()) } +#[tokio::test] +async fn deregister_udf() -> Result<()> { + let random_normal_udf = ScalarUDF::from(RandomUDF::new()); + let ctx = SessionContext::new(); + + ctx.register_udf(random_normal_udf.clone()); + + assert!(ctx.udfs().contains("random_udf")); + + ctx.deregister_udf("random_udf"); + + assert!(!ctx.udfs().contains("random_udf")); + + Ok(()) +} + fn create_udf_context() -> SessionContext { let ctx = SessionContext::new(); // register a custom UDF diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 54eab4315a97..cfd74f8861e3 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -103,6 +103,21 @@ async fn test_udwf() { assert_eq!(test_state.evaluate_all_called(), 2); } +#[tokio::test] +async fn test_deregister_udwf() -> Result<()> { + let test_state = Arc::new(TestState::new()); + let mut ctx = SessionContext::new(); + OddCounter::register(&mut ctx, Arc::clone(&test_state)); + + assert!(ctx.state().window_functions().contains_key("odd_counter")); + + ctx.deregister_udwf("odd_counter"); + + assert!(!ctx.state().window_functions().contains_key("odd_counter")); + + Ok(()) +} + /// Basic user defined window function with bounded window #[tokio::test] async fn test_udwf_bounded_window_ignores_frame() { From cfcfcb0e8bf1173c1974d259112cce1cc05e374b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 27 Feb 2024 14:18:56 -0500 Subject: [PATCH 9/9] fix fmt --- datafusion/execution/src/registry.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/execution/src/registry.rs b/datafusion/execution/src/registry.rs index 7696d4299bcf..6e0a932f0bc5 100644 --- a/datafusion/execution/src/registry.rs +++ b/datafusion/execution/src/registry.rs @@ -69,7 +69,7 @@ pub trait FunctionRegistry { /// Deregisters a [`ScalarUDF`], returning the implementation that was /// deregistered. - /// + /// /// Returns an error (the default) if the function can not be deregistered, /// for example if the registry is read only. fn deregister_udf(&mut self, _name: &str) -> Result>> { @@ -78,7 +78,7 @@ pub trait FunctionRegistry { /// Deregisters a [`AggregateUDF`], returning the implementation that was /// deregistered. - /// + /// /// Returns an error (the default) if the function can not be deregistered, /// for example if the registry is read only. fn deregister_udaf(&mut self, _name: &str) -> Result>> { @@ -87,7 +87,7 @@ pub trait FunctionRegistry { /// Deregisters a [`WindowUDF`], returning the implementation that was /// deregistered. - /// + /// /// Returns an error (the default) if the function can not be deregistered, /// for example if the registry is read only. fn deregister_udwf(&mut self, _name: &str) -> Result>> {