Skip to content

Commit

Permalink
feat: issue apache#8969 adding position function
Browse files Browse the repository at this point in the history
  • Loading branch information
Lordworms committed Jan 24, 2024
1 parent 558b3d6 commit e730036
Show file tree
Hide file tree
Showing 10 changed files with 209 additions and 46 deletions.
10 changes: 9 additions & 1 deletion datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ pub enum BuiltinScalarFunction {
NullIf,
/// octet_length
OctetLength,
/// position
Position,
/// random
Random,
/// regexp_replace
Expand Down Expand Up @@ -460,6 +462,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::MD5 => Volatility::Immutable,
BuiltinScalarFunction::NullIf => Volatility::Immutable,
BuiltinScalarFunction::OctetLength => Volatility::Immutable,
BuiltinScalarFunction::Position => Volatility::Immutable,
BuiltinScalarFunction::Radians => Volatility::Immutable,
BuiltinScalarFunction::RegexpReplace => Volatility::Immutable,
BuiltinScalarFunction::Repeat => Volatility::Immutable,
Expand Down Expand Up @@ -735,6 +738,9 @@ impl BuiltinScalarFunction {
utf8_to_int_type(&input_expr_types[0], "octet_length")
}
BuiltinScalarFunction::Pi => Ok(Float64),
BuiltinScalarFunction::Position => {
utf8_to_int_type(&input_expr_types[0], "position")
}
BuiltinScalarFunction::Random => Ok(Float64),
BuiltinScalarFunction::Uuid => Ok(Utf8),
BuiltinScalarFunction::RegexpReplace => {
Expand Down Expand Up @@ -1225,7 +1231,8 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::EndsWith
| BuiltinScalarFunction::InStr
| BuiltinScalarFunction::Strpos
| BuiltinScalarFunction::StartsWith => Signature::one_of(
| BuiltinScalarFunction::StartsWith
| BuiltinScalarFunction::Position => Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![Utf8, LargeUtf8]),
Expand Down Expand Up @@ -1498,6 +1505,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Reverse => &["reverse"],
BuiltinScalarFunction::Right => &["right"],
BuiltinScalarFunction::Rpad => &["rpad"],
BuiltinScalarFunction::Position => &["position"],
BuiltinScalarFunction::Rtrim => &["rtrim"],
BuiltinScalarFunction::SplitPart => &["split_part"],
BuiltinScalarFunction::StringToArray => {
Expand Down
6 changes: 6 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,12 @@ scalar_expr!(
string,
"returns the number of bytes of a string"
);
scalar_expr!(
Position,
position,
substring string,
"return the position of the appearence of `substring` in `string`"
);
scalar_expr!(Replace, replace, string from to, "replaces all occurrences of `from` with `to` in the `string`");
scalar_expr!(Repeat, repeat, string n, "repeats the `string` to `n` times");
scalar_expr!(Reverse, reverse, string, "reverses the `string`");
Expand Down
11 changes: 11 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,17 @@ pub fn create_physical_fun(
"Unsupported data type {other:?} for function overlay",
))),
}),
BuiltinScalarFunction::Position => Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
make_scalar_function_inner(string_expressions::position::<i32>)(args)
}
DataType::Utf8 => {
make_scalar_function_inner(string_expressions::position::<i64>)(args)
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {other:?} for function position"
))),
}),
BuiltinScalarFunction::Levenshtein => {
Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => make_scalar_function_inner(
Expand Down
53 changes: 53 additions & 0 deletions datafusion/physical-expr/src/string_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,48 @@ pub fn uuid(args: &[ColumnarValue]) -> Result<ColumnarValue> {
let array = GenericStringArray::<i32>::from_iter_values(values);
Ok(ColumnarValue::Array(Arc::new(array)))
}
/// position function, similar logic as instr
/// position('world' in 'Helloworld') = 6
pub fn position<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let substr_arr = as_generic_string_array::<T>(&args[0])?;
let str_arr = as_generic_string_array::<T>(&args[1])?;

match args[0].data_type() {
DataType::Utf8 => {
let result = str_arr
.iter()
.zip(substr_arr.iter())
.map(|(string, substr)| match (string, substr) {
(Some(string), Some(substr)) => string
.find(substr)
.map_or(Some(0), |index| Some((index + 1) as i32)),
_ => None,
})
.collect::<Int32Array>();

Ok(Arc::new(result) as ArrayRef)
}
DataType::LargeUtf8 => {
let result = str_arr
.iter()
.zip(substr_arr.iter())
.map(|(string, substr)| match (string, substr) {
(Some(string), Some(substr)) => string
.find(substr)
.map_or(Some(0), |index| Some((index + 1) as i64)),
_ => None,
})
.collect::<Int64Array>();

Ok(Arc::new(result) as ArrayRef)
}
other => {
internal_err!(
"position was called with {other} datatype arguments. It requires Utf8 or LargeUtf8."
)
}
}
}

/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2)
/// Replaces a substring of string1 with string2 starting at the integer bit
Expand Down Expand Up @@ -787,4 +829,15 @@ mod tests {

Ok(())
}
#[test]
fn to_position() -> Result<()> {
let substr_arr = Arc::new(StringArray::from(vec!["world"]));
let str_arr = Arc::new(StringArray::from(vec!["Hello, world"]));
let res = position::<i32>(&[substr_arr, str_arr]).unwrap();
let result =
as_int32_array(&res).expect("failed to initialized function position");
let expected = Int32Array::from(vec![8]);
assert_eq!(&expected, result);
Ok(())
}
}
Loading

0 comments on commit e730036

Please sign in to comment.