Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat : Support for deregistering user defined functions #9239

Merged
merged 10 commits into from
Feb 27, 2024
33 changes: 33 additions & 0 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,21 @@ impl SessionContext {
self.state.write().register_udwf(Arc::new(f)).ok();
}

/// Deregisters a UDF within this context.
pub fn deregister_udf(&self, name: &str) {
self.state.write().deregister_udf(name).ok();
}

/// Deregisters a UDAF within this context.
pub fn deregister_udaf(&self, name: &str) {
self.state.write().deregister_udaf(name).ok();
}

/// Deregisters a UDWF within this context.
pub fn deregister_udwf(&self, name: &str) {
self.state.write().deregister_udwf(name).ok();
}

/// Creates a [`DataFrame`] for reading a data source.
///
/// For more control such as reading multiple files, you can use
Expand Down Expand Up @@ -2026,6 +2041,24 @@ impl FunctionRegistry for SessionState {
fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
Ok(self.window_functions.insert(udwf.name().into(), udwf))
}

fn deregister_udf(&mut self, name: &str) -> Result<Option<Arc<ScalarUDF>>> {
let udf = self.scalar_functions.remove(name);
if let Some(udf) = &udf {
for alias in udf.aliases() {
self.scalar_functions.remove(alias);
}
}
Ok(udf)
}

fn deregister_udaf(&mut self, name: &str) -> Result<Option<Arc<AggregateUDF>>> {
Ok(self.aggregate_functions.remove(name))
}

fn deregister_udwf(&mut self, name: &str) -> Result<Option<Arc<WindowUDF>>> {
Ok(self.window_functions.remove(name))
}
}

impl OptimizerConfig for SessionState {
Expand Down
23 changes: 23 additions & 0 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,29 @@ async fn simple_udaf() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn deregister_udaf() -> Result<()> {
let ctx = SessionContext::new();
let my_avg = create_udaf(
"my_avg",
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

ctx.register_udaf(my_avg.clone());

assert!(ctx.state().aggregate_functions().contains_key("my_avg"));

ctx.deregister_udaf("my_avg");

assert!(!ctx.state().aggregate_functions().contains_key("my_avg"));

Ok(())
}

#[tokio::test]
async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
let ctx = SessionContext::new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,22 @@ async fn test_user_defined_functions_zero_argument() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn deregister_udf() -> Result<()> {
let random_normal_udf = ScalarUDF::from(RandomUDF::new());
let ctx = SessionContext::new();

ctx.register_udf(random_normal_udf.clone());

assert!(ctx.udfs().contains("random_udf"));

ctx.deregister_udf("random_udf");

assert!(!ctx.udfs().contains("random_udf"));

Ok(())
}

#[derive(Debug)]
struct TakeUDF {
signature: Signature,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,21 @@ async fn test_udwf() {
assert_eq!(test_state.evaluate_all_called(), 2);
}

#[tokio::test]
async fn test_deregister_udwf() -> Result<()> {
let test_state = Arc::new(TestState::new());
let mut ctx = SessionContext::new();
OddCounter::register(&mut ctx, Arc::clone(&test_state));

assert!(ctx.state().window_functions().contains_key("odd_counter"));

ctx.deregister_udwf("odd_counter");

assert!(!ctx.state().window_functions().contains_key("odd_counter"));

Ok(())
}

/// Basic user defined window function with bounded window
#[tokio::test]
async fn test_udwf_bounded_window_ignores_frame() {
Expand Down
27 changes: 27 additions & 0 deletions datafusion/execution/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,33 @@ pub trait FunctionRegistry {
fn register_udwf(&mut self, _udaf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
not_impl_err!("Registering WindowUDF")
}

/// Deregisters a [`ScalarUDF`], returning the implementation that was
/// deregistered.
///
/// Returns an error (the default) if the function can not be deregistered,
/// for example if the registry is read only.
fn deregister_udf(&mut self, _name: &str) -> Result<Option<Arc<ScalarUDF>>> {
not_impl_err!("Deregistering ScalarUDF")
}

/// Deregisters a [`AggregateUDF`], returning the implementation that was
/// deregistered.
///
/// Returns an error (the default) if the function can not be deregistered,
/// for example if the registry is read only.
fn deregister_udaf(&mut self, _name: &str) -> Result<Option<Arc<AggregateUDF>>> {
not_impl_err!("Deregistering AggregateUDF")
}

/// Deregisters a [`WindowUDF`], returning the implementation that was
/// deregistered.
///
/// Returns an error (the default) if the function can not be deregistered,
/// for example if the registry is read only.
fn deregister_udwf(&mut self, _name: &str) -> Result<Option<Arc<WindowUDF>>> {
not_impl_err!("Deregistering WindowUDF")
}
}

/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].
Expand Down
Loading