From b7cf00b027806f5dfac6e447f4eaf26a482333e8 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 27 Sep 2024 00:59:33 +0100 Subject: [PATCH 1/2] using encoding --- src/common_union.rs | 111 ++++++++++++++++++++++++++++++++++++++++++-- src/lib.rs | 2 + 2 files changed, 109 insertions(+), 4 deletions(-) diff --git a/src/common_union.rs b/src/common_union.rs index 8fcbeba..9538685 100644 --- a/src/common_union.rs +++ b/src/common_union.rs @@ -3,11 +3,12 @@ use std::sync::{Arc, OnceLock}; use datafusion::arrow::array::{ Array, ArrayRef, BooleanArray, Float64Array, Int64Array, NullArray, StringArray, UnionArray, }; -use datafusion::arrow::buffer::Buffer; +use datafusion::arrow::buffer::{Buffer, ScalarBuffer}; use datafusion::arrow::datatypes::{DataType, Field, UnionFields, UnionMode}; +use datafusion::arrow::error::ArrowError; use datafusion::common::ScalarValue; -pub(crate) fn is_json_union(data_type: &DataType) -> bool { +pub fn is_json_union(data_type: &DataType) -> bool { match data_type { DataType::Union(fields, UnionMode::Sparse) => fields == &union_fields(), _ => false, @@ -64,7 +65,7 @@ impl JsonUnion { strings: vec![None; length], arrays: vec![None; length], objects: vec![None; length], - type_ids: vec![0; length], + type_ids: vec![TYPE_ID_NULL; length], index: 0, length, } @@ -114,7 +115,7 @@ impl FromIterator> for JsonUnion { } impl TryFrom for UnionArray { - type Error = datafusion::arrow::error::ArrowError; + type Error = ArrowError; fn try_from(value: JsonUnion) -> Result { let children: Vec> = vec![ @@ -199,3 +200,105 @@ impl From for ScalarValue { } } } + +pub struct JsonUnionEncoder { + boolean: BooleanArray, + int: Int64Array, + float: Float64Array, + string: StringArray, + array: StringArray, + object: StringArray, + type_ids: ScalarBuffer, +} + +#[derive(Debug, PartialEq)] +pub enum JsonUnionValue<'a> { + JsonNull, + Bool(bool), + Int(i64), + Float(f64), + Str(&'a str), + Array(&'a str), + Object(&'a str), +} + +impl JsonUnionEncoder { + #[must_use] + pub fn from_union(union: UnionArray) -> Option { + if is_json_union(union.data_type()) { + let (_, type_ids, _, c) = union.into_parts(); + Some(Self { + boolean: c[1].as_any().downcast_ref::().cloned()?, + int: c[2].as_any().downcast_ref::().cloned()?, + float: c[3].as_any().downcast_ref::().cloned()?, + string: c[4].as_any().downcast_ref::().cloned()?, + array: c[5].as_any().downcast_ref::().cloned()?, + object: c[6].as_any().downcast_ref::().cloned()?, + type_ids, + }) + } else { + None + } + } + + /// Get the encodable value for a given index + /// + /// # Panics + /// + /// Panics if the idx is outside the union values or an invalid type id exists in the union. + #[must_use] + pub fn get_value(&self, idx: usize) -> JsonUnionValue { + let type_id = self.type_ids[idx]; + match type_id { + TYPE_ID_NULL => JsonUnionValue::JsonNull, + TYPE_ID_BOOL => JsonUnionValue::Bool(self.boolean.value(idx)), + TYPE_ID_INT => JsonUnionValue::Int(self.int.value(idx)), + TYPE_ID_FLOAT => JsonUnionValue::Float(self.float.value(idx)), + TYPE_ID_STR => JsonUnionValue::Str(self.string.value(idx)), + TYPE_ID_ARRAY => JsonUnionValue::Array(self.array.value(idx)), + TYPE_ID_OBJECT => JsonUnionValue::Object(self.object.value(idx)), + _ => panic!("Invalid type_id: {type_id}, not a valid JSON type"), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_json_union() { + let v = vec![ + Some(JsonUnionField::JsonNull), + Some(JsonUnionField::Bool(true)), + Some(JsonUnionField::Bool(false)), + Some(JsonUnionField::Int(42)), + Some(JsonUnionField::Float(42.0)), + Some(JsonUnionField::Str("foo".to_string())), + Some(JsonUnionField::Array("[42]".to_string())), + Some(JsonUnionField::Object(r#"{"foo": 42}"#.to_string())), + None, + ]; + let length = v.len(); + let json_union = JsonUnion::from_iter(v); + + let union_array = UnionArray::try_from(json_union).unwrap(); + let encoder = JsonUnionEncoder::from_union(union_array).unwrap(); + + let values_after: Vec<_> = (0..length).map(|idx| encoder.get_value(idx)).collect(); + assert_eq!( + values_after, + vec![ + JsonUnionValue::JsonNull, + JsonUnionValue::Bool(true), + JsonUnionValue::Bool(false), + JsonUnionValue::Int(42), + JsonUnionValue::Float(42.0), + JsonUnionValue::Str("foo"), + JsonUnionValue::Array("[42]"), + JsonUnionValue::Object(r#"{"foo": 42}"#), + JsonUnionValue::JsonNull, + ] + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index 75b18f6..692478e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,6 +19,8 @@ mod json_get_str; mod json_length; mod rewrite; +pub use common_union::{JsonUnionEncoder, JsonUnionValue}; + pub mod functions { pub use crate::json_as_text::json_as_text; pub use crate::json_contains::json_contains; From ab565afada633f702c4d88f4b60a00190f9cac16 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 27 Sep 2024 01:01:14 +0100 Subject: [PATCH 2/2] using encoding --- src/common_union.rs | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/src/common_union.rs b/src/common_union.rs index 9538685..871dbc0 100644 --- a/src/common_union.rs +++ b/src/common_union.rs @@ -211,17 +211,6 @@ pub struct JsonUnionEncoder { type_ids: ScalarBuffer, } -#[derive(Debug, PartialEq)] -pub enum JsonUnionValue<'a> { - JsonNull, - Bool(bool), - Int(i64), - Float(f64), - Str(&'a str), - Array(&'a str), - Object(&'a str), -} - impl JsonUnionEncoder { #[must_use] pub fn from_union(union: UnionArray) -> Option { @@ -241,6 +230,12 @@ impl JsonUnionEncoder { } } + #[must_use] + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + self.type_ids.len() + } + /// Get the encodable value for a given index /// /// # Panics @@ -262,13 +257,24 @@ impl JsonUnionEncoder { } } +#[derive(Debug, PartialEq)] +pub enum JsonUnionValue<'a> { + JsonNull, + Bool(bool), + Int(i64), + Float(f64), + Str(&'a str), + Array(&'a str), + Object(&'a str), +} + #[cfg(test)] mod test { use super::*; #[test] fn test_json_union() { - let v = vec![ + let json_union = JsonUnion::from_iter(vec![ Some(JsonUnionField::JsonNull), Some(JsonUnionField::Bool(true)), Some(JsonUnionField::Bool(false)), @@ -278,14 +284,12 @@ mod test { Some(JsonUnionField::Array("[42]".to_string())), Some(JsonUnionField::Object(r#"{"foo": 42}"#.to_string())), None, - ]; - let length = v.len(); - let json_union = JsonUnion::from_iter(v); + ]); let union_array = UnionArray::try_from(json_union).unwrap(); let encoder = JsonUnionEncoder::from_union(union_array).unwrap(); - let values_after: Vec<_> = (0..length).map(|idx| encoder.get_value(idx)).collect(); + let values_after: Vec<_> = (0..encoder.len()).map(|idx| encoder.get_value(idx)).collect(); assert_eq!( values_after, vec![