Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

set keys to null where applicable in dictionary-encoded results #46

Merged
merged 5 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 76 additions & 33 deletions src/common.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
use std::str::Utf8Error;
use std::sync::Arc;

use datafusion::arrow::array::{
Array, ArrayRef, AsArray, Int64Array, LargeStringArray, StringArray, StringViewArray, UInt64Array,
Array, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray, StringArray,
StringViewArray, UInt64Array, UnionArray,
};
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType};
use datafusion::arrow::downcast_dictionary_array;
use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::ColumnarValue;
use jiter::{Jiter, JiterError, Peek};

use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array};
use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array, TYPE_ID_NULL};

/// General implementation of `ScalarUDFImpl::return_type`.
///
Expand Down Expand Up @@ -164,21 +167,32 @@ fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Optio
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
object_lookup: bool,
) -> DataFusionResult<ArrayRef> {
if let Some(d) = json_array.as_any_dictionary_opt() {
let a = zip_apply(d.values(), path_array, to_array, jiter_find, object_lookup)?;
return Ok(d.with_values(a).into());
}
let c = if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() {
zip_apply_iter(string_array.iter(), path_array, jiter_find)
} else if let Some(large_string_array) = json_array.as_any().downcast_ref::<LargeStringArray>() {
zip_apply_iter(large_string_array.iter(), path_array, jiter_find)
} else if let Some(string_view) = json_array.as_any().downcast_ref::<StringViewArray>() {
zip_apply_iter(string_view.iter(), path_array, jiter_find)
} else if let Some(string_array) = nested_json_array(json_array, object_lookup) {
zip_apply_iter(string_array.iter(), path_array, jiter_find)
} else {
return exec_err!("unexpected json array type {:?}", json_array.data_type());
};
// arrow_schema "use" is workaround for https://github.com/apache/arrow-rs/issues/6400#issue-2528388332
use datafusion::arrow::datatypes as arrow_schema;

let c = downcast_dictionary_array!(
json_array => {
let values = zip_apply(json_array.values(), path_array, to_array, jiter_find, object_lookup)?;
if !is_json_union(values.data_type()) {
return Ok(Arc::new(json_array.with_values(values)));
}
// JSON union: post-process the array to set keys to null where the union member is null
let type_ids = values.as_any().downcast_ref::<UnionArray>().unwrap().type_ids();
return Ok(Arc::new(DictionaryArray::new(
mask_dictionary_keys(json_array.keys(), type_ids),
values,
)));
}
DataType::Utf8 => zip_apply_iter(json_array.as_string::<i32>().iter(), path_array, jiter_find),
DataType::LargeUtf8 => zip_apply_iter(json_array.as_string::<i64>().iter(), path_array, jiter_find),
DataType::Utf8View => zip_apply_iter(json_array.as_string_view().iter(), path_array, jiter_find),
other => if let Some(string_array) = nested_json_array(json_array, object_lookup) {
zip_apply_iter(string_array.iter(), path_array, jiter_find)
} else {
return exec_err!("unexpected json array type {:?}", other);
}
);

to_array(c)
}

Expand Down Expand Up @@ -229,22 +243,31 @@ fn scalar_apply<C: FromIterator<Option<I>>, I>(
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
) -> DataFusionResult<ArrayRef> {
if let Some(d) = json_array.as_any_dictionary_opt() {
let a = scalar_apply(d.values(), path, to_array, jiter_find)?;
return Ok(d.with_values(a).into());
}
// arrow_schema "use" is workaround for https://github.com/apache/arrow-rs/issues/6400#issue-2528388332
use datafusion::arrow::datatypes as arrow_schema;

let c = if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() {
scalar_apply_iter(string_array.iter(), path, jiter_find)
} else if let Some(large_string_array) = json_array.as_any().downcast_ref::<LargeStringArray>() {
scalar_apply_iter(large_string_array.iter(), path, jiter_find)
} else if let Some(string_view_array) = json_array.as_any().downcast_ref::<StringViewArray>() {
scalar_apply_iter(string_view_array.iter(), path, jiter_find)
} else if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) {
scalar_apply_iter(string_array.iter(), path, jiter_find)
} else {
return exec_err!("unexpected json array type {:?}", json_array.data_type());
};
let c = downcast_dictionary_array!(
json_array => {
let values = scalar_apply(json_array.values(), path, to_array, jiter_find)?;
if !is_json_union(values.data_type()) {
return Ok(Arc::new(json_array.with_values(values)));
}
// JSON union: post-process the array to set keys to null where the union member is null
let type_ids = values.as_any().downcast_ref::<UnionArray>().unwrap().type_ids();
return Ok(Arc::new(DictionaryArray::new(
mask_dictionary_keys(json_array.keys(), type_ids),
values,
)));
}
DataType::Utf8 => scalar_apply_iter(json_array.as_string::<i32>().iter(), path, jiter_find),
DataType::LargeUtf8 => scalar_apply_iter(json_array.as_string::<i64>().iter(), path, jiter_find),
DataType::Utf8View => scalar_apply_iter(json_array.as_string_view().iter(), path, jiter_find),
other => if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) {
scalar_apply_iter(string_array.iter(), path, jiter_find)
} else {
return exec_err!("unexpected json array type {:?}", other);
}
);

to_array(c)
}
Expand Down Expand Up @@ -319,3 +342,23 @@ impl From<Utf8Error> for GetError {
GetError
}
}

/// Return a copy of `keys` whose validity is additionally cleared wherever the
/// referenced union value is the JSON-null member (or the key was already null).
///
/// Workaround for <https://github.com/apache/arrow-rs/issues/6017#issuecomment-2352756753>
/// - dictionary-level null is most reliably represented by nulling the keys themselves.
///
/// As a side benefit, this can act as an optimization: null checks on the dictionary
/// no longer need to consult the union values array at all.
fn mask_dictionary_keys<K: ArrowPrimitiveType>(keys: &PrimitiveArray<K>, type_ids: &[i8]) -> PrimitiveArray<K> {
    // A key stays valid only when it is present AND points at a non-null union member.
    let validity: Vec<bool> = keys
        .iter()
        .map(|key| matches!(key, Some(k) if type_ids[k.as_usize()] != TYPE_ID_NULL))
        .collect();
    PrimitiveArray::new(keys.values().clone(), Some(validity.into()))
}
2 changes: 1 addition & 1 deletion src/common_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ pub(crate) enum JsonUnionField {
Object(String),
}

const TYPE_ID_NULL: i8 = 0;
pub(crate) const TYPE_ID_NULL: i8 = 0;
const TYPE_ID_BOOL: i8 = 1;
const TYPE_ID_INT: i8 = 2;
const TYPE_ID_FLOAT: i8 = 3;
Expand Down
2 changes: 1 addition & 1 deletion src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ fn unnest_json_calls(func: &ScalarFunction) -> Option<Transformed<Expr>> {
fn extract_scalar_function(expr: &Expr) -> Option<&ScalarFunction> {
match expr {
Expr::ScalarFunction(func) => Some(func),
Expr::Alias(alias) => extract_scalar_function(&*alias.expr),
Expr::Alias(alias) => extract_scalar_function(&alias.expr),
_ => None,
}
}
Expand Down
54 changes: 49 additions & 5 deletions tests/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use datafusion::common::ScalarValue;
use datafusion::logical_expr::ColumnarValue;

use datafusion_functions_json::udfs::json_get_str_udf;
use utils::{display_val, logical_plan, run_query, run_query_large, run_query_params};
use utils::{display_val, logical_plan, run_query, run_query_dict, run_query_large, run_query_params};

mod utils;

Expand Down Expand Up @@ -1072,6 +1072,28 @@ async fn test_arrow_union_is_null() {
assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_arrow_union_is_null_dict_encoded() {
    // Same query as test_arrow_union_is_null, but against the dictionary-encoded
    // JSON column: masked keys must surface as SQL NULL.
    let sql = "select name, (json_data->'foo') is null from test";
    let results = run_query_dict(sql).await.unwrap();

    let expected = [
        "+------------------+---------------------------------------+",
        "| name             | test.json_data -> Utf8(\"foo\") IS NULL |",
        "+------------------+---------------------------------------+",
        "| object_foo       | false                                 |",
        "| object_foo_array | false                                 |",
        "| object_foo_obj   | false                                 |",
        "| object_foo_null  | true                                  |",
        "| object_bar       | true                                  |",
        "| list_foo         | true                                  |",
        "| invalid_json     | true                                  |",
        "+------------------+---------------------------------------+",
    ];
    assert_batches_eq!(expected, &results);
}

#[tokio::test]
async fn test_arrow_union_is_not_null() {
let batches = run_query("select name, (json_data->'foo') is not null from test")
Expand All @@ -1094,6 +1116,28 @@ async fn test_arrow_union_is_not_null() {
assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_arrow_union_is_not_null_dict_encoded() {
    // Same query as test_arrow_union_is_not_null, but against the dictionary-encoded
    // JSON column: masked keys must surface as SQL NULL (so IS NOT NULL is false).
    let sql = "select name, (json_data->'foo') is not null from test";
    let results = run_query_dict(sql).await.unwrap();

    let expected = [
        "+------------------+-------------------------------------------+",
        "| name             | test.json_data -> Utf8(\"foo\") IS NOT NULL |",
        "+------------------+-------------------------------------------+",
        "| object_foo       | true                                      |",
        "| object_foo_array | true                                      |",
        "| object_foo_obj   | true                                      |",
        "| object_foo_null  | false                                     |",
        "| object_bar       | false                                     |",
        "| list_foo         | false                                     |",
        "| invalid_json     | false                                     |",
        "+------------------+-------------------------------------------+",
    ];
    assert_batches_eq!(expected, &results);
}

#[tokio::test]
async fn test_arrow_scalar_union_is_null() {
let batches = run_query(
Expand Down Expand Up @@ -1147,8 +1191,8 @@ async fn test_dict_haystack() {
"| v |",
"+-----------------------+",
"| {object={\"bar\": [0]}} |",
"| {null=} |",
"| {null=} |",
"| |",
"| |",
"+-----------------------+",
];

Expand All @@ -1164,8 +1208,8 @@ async fn test_dict_haystack_needle() {
"| v |",
"+-------------+",
"| {array=[0]} |",
"| {null=} |",
"| {null=} |",
"| |",
"| |",
"+-------------+",
];

Expand Down
28 changes: 21 additions & 7 deletions tests/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
use std::sync::Arc;

use datafusion::arrow::array::{
ArrayRef, DictionaryArray, Int64Array, StringViewArray, UInt32Array, UInt64Array, UInt8Array,
ArrayRef, DictionaryArray, Int32Array, Int64Array, StringViewArray, UInt32Array, UInt64Array, UInt8Array,
};
use datafusion::arrow::datatypes::{DataType, Field, Int64Type, Schema, UInt32Type, UInt8Type};
use datafusion::arrow::datatypes::{DataType, Field, Int32Type, Int64Type, Schema, UInt32Type, UInt8Type};
use datafusion::arrow::util::display::{ArrayFormatter, FormatOptions};
use datafusion::arrow::{array::LargeStringArray, array::StringArray, record_batch::RecordBatch};
use datafusion::common::ParamValues;
Expand All @@ -13,7 +13,7 @@ use datafusion::execution::context::SessionContext;
use datafusion::prelude::SessionConfig;
use datafusion_functions_json::register_all;

async fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
async fn create_test_table(large_utf8: bool, dict_encoded: bool) -> Result<SessionContext> {
let config = SessionConfig::new().set_str("datafusion.sql_parser.dialect", "postgres");
let mut ctx = SessionContext::new_with_config(config);
register_all(&mut ctx)?;
Expand All @@ -28,11 +28,20 @@ async fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
("invalid_json", "is not json"),
];
let json_values = test_data.iter().map(|(_, json)| *json).collect::<Vec<_>>();
let (json_data_type, json_array): (DataType, ArrayRef) = if large_utf8 {
let (mut json_data_type, mut json_array): (DataType, ArrayRef) = if large_utf8 {
(DataType::LargeUtf8, Arc::new(LargeStringArray::from(json_values)))
} else {
(DataType::Utf8, Arc::new(StringArray::from(json_values)))
};

if dict_encoded {
json_data_type = DataType::Dictionary(DataType::Int32.into(), json_data_type.into());
json_array = Arc::new(DictionaryArray::<Int32Type>::new(
Int32Array::from_iter_values(0..(json_array.len() as i32)),
json_array,
));
}

let test_batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Expand Down Expand Up @@ -178,12 +187,17 @@ async fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
}

pub async fn run_query(sql: &str) -> Result<Vec<RecordBatch>> {
let ctx = create_test_table(false).await?;
let ctx = create_test_table(false, false).await?;
ctx.sql(sql).await?.collect().await
}

pub async fn run_query_large(sql: &str) -> Result<Vec<RecordBatch>> {
let ctx = create_test_table(true).await?;
let ctx = create_test_table(true, false).await?;
ctx.sql(sql).await?.collect().await
}

/// Run `sql` against the test table with the JSON column dictionary-encoded
/// (Utf8 values, Int32 keys), returning the collected record batches.
pub async fn run_query_dict(sql: &str) -> Result<Vec<RecordBatch>> {
    let context = create_test_table(false, true).await?;
    let frame = context.sql(sql).await?;
    frame.collect().await
}

Expand All @@ -192,7 +206,7 @@ pub async fn run_query_params(
large_utf8: bool,
query_values: impl Into<ParamValues>,
) -> Result<Vec<RecordBatch>> {
let ctx = create_test_table(large_utf8).await?;
let ctx = create_test_table(large_utf8, false).await?;
ctx.sql(sql).await?.with_param_values(query_values)?.collect().await
}

Expand Down
Loading