Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Union encoding #49

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 111 additions & 4 deletions src/common_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ use std::sync::{Arc, OnceLock};
use datafusion::arrow::array::{
Array, ArrayRef, BooleanArray, Float64Array, Int64Array, NullArray, StringArray, UnionArray,
};
use datafusion::arrow::buffer::Buffer;
use datafusion::arrow::buffer::{Buffer, ScalarBuffer};
use datafusion::arrow::datatypes::{DataType, Field, UnionFields, UnionMode};
use datafusion::arrow::error::ArrowError;
use datafusion::common::ScalarValue;

pub(crate) fn is_json_union(data_type: &DataType) -> bool {
pub fn is_json_union(data_type: &DataType) -> bool {
match data_type {
DataType::Union(fields, UnionMode::Sparse) => fields == &union_fields(),
_ => false,
Expand Down Expand Up @@ -64,7 +65,7 @@ impl JsonUnion {
strings: vec![None; length],
arrays: vec![None; length],
objects: vec![None; length],
type_ids: vec![0; length],
type_ids: vec![TYPE_ID_NULL; length],
index: 0,
length,
}
Expand Down Expand Up @@ -114,7 +115,7 @@ impl FromIterator<Option<JsonUnionField>> for JsonUnion {
}

impl TryFrom<JsonUnion> for UnionArray {
type Error = datafusion::arrow::error::ArrowError;
type Error = ArrowError;

fn try_from(value: JsonUnion) -> Result<Self, Self::Error> {
let children: Vec<Arc<dyn Array>> = vec![
Expand Down Expand Up @@ -199,3 +200,109 @@ impl From<JsonUnionField> for ScalarValue {
}
}
}

/// Read-only view over the children of a JSON union array, used to decode
/// individual slots into [`JsonUnionValue`]s.
///
/// Built with [`JsonUnionEncoder::from_union`]; each field holds one child array
/// of the source union, and `type_ids` records which child is active per slot.
#[derive(Debug, Clone)]
pub struct JsonUnionEncoder {
    /// Child array read for slots whose type id is `TYPE_ID_BOOL`.
    boolean: BooleanArray,
    /// Child array read for slots whose type id is `TYPE_ID_INT`.
    int: Int64Array,
    /// Child array read for slots whose type id is `TYPE_ID_FLOAT`.
    float: Float64Array,
    /// Child array read for slots whose type id is `TYPE_ID_STR`.
    string: StringArray,
    /// Child array read for slots whose type id is `TYPE_ID_ARRAY`.
    array: StringArray,
    /// Child array read for slots whose type id is `TYPE_ID_OBJECT`.
    object: StringArray,
    /// Per-slot type id; its length is the number of values in the union.
    type_ids: ScalarBuffer<i8>,
}

impl JsonUnionEncoder {
#[must_use]
pub fn from_union(union: UnionArray) -> Option<Self> {
if is_json_union(union.data_type()) {
let (_, type_ids, _, c) = union.into_parts();
Some(Self {
boolean: c[1].as_any().downcast_ref::<BooleanArray>().cloned()?,
int: c[2].as_any().downcast_ref::<Int64Array>().cloned()?,
float: c[3].as_any().downcast_ref::<Float64Array>().cloned()?,
string: c[4].as_any().downcast_ref::<StringArray>().cloned()?,
array: c[5].as_any().downcast_ref::<StringArray>().cloned()?,
object: c[6].as_any().downcast_ref::<StringArray>().cloned()?,
Comment on lines +220 to +225
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like allowing the downcasts to fail here is unnecessary due to the is_json_union check.

Suggested change
boolean: c[1].as_any().downcast_ref::<BooleanArray>().cloned()?,
int: c[2].as_any().downcast_ref::<Int64Array>().cloned()?,
float: c[3].as_any().downcast_ref::<Float64Array>().cloned()?,
string: c[4].as_any().downcast_ref::<StringArray>().cloned()?,
array: c[5].as_any().downcast_ref::<StringArray>().cloned()?,
object: c[6].as_any().downcast_ref::<StringArray>().cloned()?,
boolean: c[1].as_boolean().clone(),
int: c[2].as_primitive().clone(),
float: c[3].as_primitive().clone(),
string: c[4].as_string().clone(),
array: c[5].as_string().clone(),
object: c[6].as_string().clone(),

type_ids,
})
} else {
None
}
}

#[must_use]
#[allow(clippy::len_without_is_empty)]
pub fn len(&self) -> usize {
self.type_ids.len()
}

/// Decode the union slot at `idx` into a borrowed [`JsonUnionValue`].
///
/// The type id stored at `idx` selects which child array the value is read from.
///
/// # Panics
///
/// Panics if `idx` is out of bounds for the union values, or if the type id
/// stored at `idx` is not one of the known JSON type ids.
#[must_use]
pub fn get_value(&self, idx: usize) -> JsonUnionValue {
    match self.type_ids[idx] {
        TYPE_ID_NULL => JsonUnionValue::JsonNull,
        TYPE_ID_BOOL => {
            let flag = self.boolean.value(idx);
            JsonUnionValue::Bool(flag)
        }
        TYPE_ID_INT => {
            let number = self.int.value(idx);
            JsonUnionValue::Int(number)
        }
        TYPE_ID_FLOAT => {
            let number = self.float.value(idx);
            JsonUnionValue::Float(number)
        }
        TYPE_ID_STR => {
            let text = self.string.value(idx);
            JsonUnionValue::Str(text)
        }
        TYPE_ID_ARRAY => {
            let text = self.array.value(idx);
            JsonUnionValue::Array(text)
        }
        TYPE_ID_OBJECT => {
            let text = self.object.value(idx);
            JsonUnionValue::Object(text)
        }
        // Any other id means the union was not built from the JSON union fields.
        type_id => panic!("Invalid type_id: {type_id}, not a valid JSON type"),
    }
}
}

/// A single decoded slot of a JSON union, borrowing any string data from the
/// encoder it was read from.
///
/// Every payload is `Copy`, so the enum itself is `Copy` and can be passed by
/// value without moving it out of the caller's hands.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum JsonUnionValue<'a> {
    /// A null slot.
    JsonNull,
    /// A boolean value.
    Bool(bool),
    /// A 64-bit signed integer value.
    Int(i64),
    /// A 64-bit floating point value.
    Float(f64),
    /// A string value.
    Str(&'a str),
    /// An array, carried as a string slice (e.g. `[42]`).
    Array(&'a str),
    /// An object, carried as a string slice (e.g. `{"foo": 42}`).
    Object(&'a str),
}

#[cfg(test)]
mod test {
    use super::*;

    /// Round-trips every JSON variant, plus a missing entry, through
    /// `JsonUnion` -> `UnionArray` -> `JsonUnionEncoder` and checks what is decoded.
    #[test]
    fn test_json_union() {
        let fields = vec![
            Some(JsonUnionField::JsonNull),
            Some(JsonUnionField::Bool(true)),
            Some(JsonUnionField::Bool(false)),
            Some(JsonUnionField::Int(42)),
            Some(JsonUnionField::Float(42.0)),
            Some(JsonUnionField::Str("foo".to_string())),
            Some(JsonUnionField::Array("[42]".to_string())),
            Some(JsonUnionField::Object(r#"{"foo": 42}"#.to_string())),
            None,
        ];
        let json_union: JsonUnion = fields.into_iter().collect();

        let union_array = UnionArray::try_from(json_union).unwrap();
        let encoder = JsonUnionEncoder::from_union(union_array).unwrap();

        // Decode slot by slot, in order.
        let mut decoded = Vec::with_capacity(encoder.len());
        for idx in 0..encoder.len() {
            decoded.push(encoder.get_value(idx));
        }

        // The trailing `None` entry is expected to decode as a JSON null.
        let expected = vec![
            JsonUnionValue::JsonNull,
            JsonUnionValue::Bool(true),
            JsonUnionValue::Bool(false),
            JsonUnionValue::Int(42),
            JsonUnionValue::Float(42.0),
            JsonUnionValue::Str("foo"),
            JsonUnionValue::Array("[42]"),
            JsonUnionValue::Object(r#"{"foo": 42}"#),
            JsonUnionValue::JsonNull,
        ];
        assert_eq!(decoded, expected);
    }
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ mod json_get_str;
mod json_length;
mod rewrite;

pub use common_union::{JsonUnionEncoder, JsonUnionValue};

pub mod functions {
pub use crate::json_as_text::json_as_text;
pub use crate::json_contains::json_contains;
Expand Down
Loading