Skip to content

Commit

Permalink
Merge remote-tracking branch 'apache/main' into alamb/unparser_docs
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jul 12, 2024
2 parents eba6e60 + 5ba634a commit 8ca435d
Show file tree
Hide file tree
Showing 18 changed files with 1,025 additions and 67 deletions.
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

79 changes: 74 additions & 5 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,19 @@
//! This module contains end to end demonstrations of creating
//! user defined aggregate functions

use arrow::{array::AsArray, datatypes::Fields};
use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray};
use arrow_schema::Schema;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};

use arrow::{array::AsArray, datatypes::Fields};
use arrow_array::{
types::UInt64Type, Int32Array, PrimitiveArray, StringArray, StructArray,
};
use arrow_schema::Schema;

use datafusion::dataframe::DataFrame;
use datafusion::datasource::MemTable;
use datafusion::test_util::plan_and_collect;
use datafusion::{
Expand All @@ -45,8 +50,8 @@ use datafusion::{
};
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
use datafusion_expr::{
create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
SimpleAggregateUDF,
col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
LogicalPlanBuilder, SimpleAggregateUDF,
};
use datafusion_functions_aggregate::average::AvgAccumulator;

Expand Down Expand Up @@ -377,6 +382,55 @@ async fn test_groups_accumulator() -> Result<()> {
Ok(())
}

/// Checks that two aggregate UDFs built from the same implementation type,
/// with the same signature but a different `result` parameter, are planned
/// and executed as two distinct functions.
///
/// Both calls print identically in the logical plan, so the only observable
/// difference is the value each column holds after execution (`a` = 1,
/// `b` = 2).
#[tokio::test]
async fn test_parameterized_aggregate_udf() -> Result<()> {
    let batch = RecordBatch::try_from_iter([(
        "text",
        Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
    )])?;

    let ctx = SessionContext::new();
    ctx.register_batch("t", batch)?;
    let t = ctx.table("t").await?;
    let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable);
    let udf1 = AggregateUDF::from(TestGroupsAccumulator {
        signature: signature.clone(),
        result: 1,
    });
    // Last use of `signature`: move it instead of cloning again.
    let udf2 = AggregateUDF::from(TestGroupsAccumulator {
        signature,
        result: 2,
    });

    let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?)
        .aggregate(
            [col("text")],
            [
                udf1.call(vec![col("text")]).alias("a"),
                udf2.call(vec![col("text")]).alias("b"),
            ],
        )?
        .build()?;

    // Child plan nodes are indented by two spaces per nesting level.
    assert_eq!(
        format!("{plan:?}"),
        "Aggregate: groupBy=[[t.text]], aggr=[[geo_mean(t.text) AS a, geo_mean(t.text) AS b]]\n  TableScan: t projection=[text]"
    );

    let actual = DataFrame::new(ctx.state(), plan).collect().await?;
    // Cells are padded to the column width, so every row is as wide as the border.
    let expected = [
        "+------+---+---+",
        "| text | a | b |",
        "+------+---+---+",
        "| foo  | 1 | 2 |",
        "+------+---+---+",
    ];
    assert_batches_eq!(expected, &actual);

    ctx.deregister_table("t")?;
    Ok(())
}

/// Returns a context with a table "t" and the "first" and "time_sum"
/// aggregate functions registered.
///
Expand Down Expand Up @@ -735,6 +789,21 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
) -> Result<Box<dyn GroupsAccumulator>> {
Ok(Box::new(self.clone()))
}

/// Two accumulators are the same function only when `other` is also a
/// `TestGroupsAccumulator` and both its `result` and `signature` match;
/// any other implementation type compares unequal.
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
    other
        .as_any()
        .downcast_ref::<TestGroupsAccumulator>()
        .is_some_and(|that| {
            self.result == that.result && self.signature == that.signature
        })
}

/// Hashes the `signature` followed by the `result`, i.e. exactly the
/// fields that `equals` compares.
fn hash_value(&self) -> u64 {
    let mut state = DefaultHasher::new();
    self.signature.hash(&mut state);
    self.result.hash(&mut state);
    state.finish()
}
}

impl Accumulator for TestGroupsAccumulator {
Expand Down
128 changes: 125 additions & 3 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,20 @@
// under the License.

use std::any::Any;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;

use arrow::compute::kernels::numeric::add;
use arrow_array::{ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch};
use arrow_array::builder::BooleanBuilder;
use arrow_array::cast::AsArray;
use arrow_array::{
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use parking_lot::Mutex;
use regex::Regex;
use sqlparser::ast::Ident;

use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
use datafusion::prelude::*;
use datafusion::{execution::registry::FunctionRegistry, test_util};
Expand All @@ -37,8 +46,6 @@ use datafusion_expr::{
Volatility,
};
use datafusion_functions_array::range::range_udf;
use parking_lot::Mutex;
use sqlparser::ast::Ident;

/// test that casting happens on udfs.
/// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and
Expand Down Expand Up @@ -1021,6 +1028,121 @@ async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<(
Ok(())
}

/// A scalar UDF parameterized by a regular expression.
///
/// Each instance carries its own compiled pattern, so several instances can
/// coexist with different behavior.
#[derive(Debug)]
struct MyRegexUdf {
/// Argument signature reported for this UDF
signature: Signature,
/// Compiled pattern that input strings are tested against
regex: Regex,
}

impl MyRegexUdf {
    /// Builds a UDF that takes exactly one `Utf8` argument and tests it
    /// against `pattern`.
    ///
    /// # Panics
    ///
    /// Panics if `pattern` is not a valid regular expression.
    fn new(pattern: &str) -> Self {
        let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable);
        let regex = Regex::new(pattern).expect("regex");
        Self { signature, regex }
    }

    /// Tests `value` against the pattern; a missing value yields `None`
    /// rather than `Some(false)`.
    fn matches(&self, value: Option<&str>) -> Option<bool> {
        value.map(|text| self.regex.is_match(text))
    }
}

/// Scalar UDF implementation that returns whether its single `Utf8` argument
/// matches the instance's regular expression.
impl ScalarUDFImpl for MyRegexUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Every instance reports the same name, whatever its pattern is.
    fn name(&self) -> &str {
        "regex_udf"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Returns `Boolean` for a single `Utf8` argument and a planning error
    /// for any other argument list.
    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
        if matches!(args, [DataType::Utf8]) {
            Ok(DataType::Boolean)
        } else {
            plan_err!("regex_udf only accepts a Utf8 argument")
        }
    }

    /// Evaluates the pattern against a scalar or an array of strings.
    ///
    /// Null inputs produce null outputs. Anything other than exactly one
    /// `Utf8` scalar or one array argument is an execution error.
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        match args {
            [ColumnarValue::Scalar(ScalarValue::Utf8(value))] => {
                Ok(ColumnarValue::Scalar(ScalarValue::Boolean(
                    self.matches(value.as_deref()),
                )))
            }
            [ColumnarValue::Array(values)] => {
                let mut builder = BooleanBuilder::with_capacity(values.len());
                for value in values.as_string::<i32>() {
                    builder.append_option(self.matches(value));
                }
                Ok(ColumnarValue::Array(Arc::new(builder.finish())))
            }
            // Same wording as the `return_type` error above.
            _ => exec_err!("regex_udf only accepts a Utf8 argument"),
        }
    }

    /// Instances are equal only when `other` is also a `MyRegexUdf` with the
    /// same pattern text; the shared name alone does not make them equal.
    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
        if let Some(other) = other.as_any().downcast_ref::<MyRegexUdf>() {
            self.regex.as_str() == other.regex.as_str()
        } else {
            false
        }
    }

    /// Hashes the pattern text, the same data `equals` compares.
    fn hash_value(&self) -> u64 {
        let hasher = &mut DefaultHasher::new();
        self.regex.as_str().hash(hasher);
        hasher.finish()
    }
}

/// Checks that two scalar UDFs of the same implementation type, configured
/// with different regular expressions, stay distinct through planning and
/// execution.
///
/// Both calls print as `regex_udf(t.text)` in the plan, so the evidence is in
/// the result: only rows matching *both* patterns remain.
#[tokio::test]
async fn test_parameterized_scalar_udf() -> Result<()> {
    let batch = RecordBatch::try_from_iter([(
        "text",
        Arc::new(StringArray::from(vec!["foo", "bar", "foobar", "barfoo"])) as ArrayRef,
    )])?;

    let ctx = SessionContext::new();
    ctx.register_batch("t", batch)?;
    let t = ctx.table("t").await?;
    let foo_udf = ScalarUDF::from(MyRegexUdf::new("fo{2}"));
    let bar_udf = ScalarUDF::from(MyRegexUdf::new("[Bb]ar"));

    let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?)
        .filter(
            foo_udf
                .call(vec![col("text")])
                .and(bar_udf.call(vec![col("text")])),
        )?
        .filter(col("text").is_not_null())?
        .build()?;

    // Child plan nodes are indented by two spaces per nesting level.
    assert_eq!(
        format!("{plan:?}"),
        "Filter: t.text IS NOT NULL\n  Filter: regex_udf(t.text) AND regex_udf(t.text)\n    TableScan: t projection=[text]"
    );

    let actual = DataFrame::new(ctx.state(), plan).collect().await?;
    // Cells are padded to the column width, so every row is as wide as the border.
    let expected = [
        "+--------+",
        "| text   |",
        "+--------+",
        "| foobar |",
        "| barfoo |",
        "+--------+",
    ];
    assert_batches_eq!(expected, &actual);

    ctx.deregister_table("t")?;
    Ok(())
}

fn create_udf_context() -> SessionContext {
let ctx = SessionContext::new();
// register a custom UDF
Expand Down
7 changes: 7 additions & 0 deletions datafusion/expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ pub trait ExprPlanner: Send + Sync {
) -> Result<PlannerResult<Vec<Expr>>> {
Ok(PlannerResult::Original(args))
}

/// Plans an overlay expression, e.g. `overlay(str PLACING substr FROM pos [FOR count])`
///
/// Returns the original expression arguments unchanged if this planner
/// cannot plan the expression. The default implementation always does so.
fn plan_overlay(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
Ok(PlannerResult::Original(args))
}
}

/// An operator with two arguments to plan
Expand Down
Loading

0 comments on commit 8ca435d

Please sign in to comment.