Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BEAM-13960] Add support for more types when converting from between row and proto #16875

Merged
merged 2 commits into from
Mar 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion model/pipeline/src/main/proto/schema.proto
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ message LogicalType {
message Option {
// REQUIRED. Identifier for the option.
string name = 1;
// REQUIRED. Type specifer for the structure of value.
// REQUIRED. Type specifier for the structure of value.
// Conventionally, options that don't require additional configuration should
// use a boolean type, with the value set to true.
FieldType type = 2;
Expand All @@ -125,6 +125,7 @@ message Row {
}

message FieldValue {
// If none of these are set, value is considered null.
oneof field_value {
AtomicTypeValue atomic_value = 1;
ArrayTypeValue array_value = 2;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
*/
package org.apache.beam.sdk.schemas;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
Expand All @@ -28,6 +33,7 @@
import org.apache.beam.model.pipeline.v1.SchemaApi.AtomicTypeValue;
import org.apache.beam.model.pipeline.v1.SchemaApi.FieldValue;
import org.apache.beam.model.pipeline.v1.SchemaApi.IterableTypeValue;
import org.apache.beam.model.pipeline.v1.SchemaApi.LogicalTypeValue;
import org.apache.beam.model.pipeline.v1.SchemaApi.MapTypeEntry;
import org.apache.beam.model.pipeline.v1.SchemaApi.MapTypeValue;
import org.apache.beam.sdk.annotations.Experimental;
Expand All @@ -45,6 +51,7 @@
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
import org.checkerframework.checker.nullness.qual.Nullable;

/** Utility methods for translating schemas. */
Expand Down Expand Up @@ -319,6 +326,7 @@ private static FieldType fieldTypeFromProtoWithoutNullable(SchemaApi.FieldType p
fieldTypeFromProto(protoFieldType.getMapType().getValueType()));
case LOGICAL_TYPE:
String urn = protoFieldType.getLogicalType().getUrn();
SchemaApi.LogicalType logicalType = protoFieldType.getLogicalType();
Class<? extends LogicalType<?, ?>> logicalTypeClass = STANDARD_LOGICAL_TYPES.get(urn);
if (logicalTypeClass != null) {
try {
Expand Down Expand Up @@ -351,22 +359,21 @@ private static FieldType fieldTypeFromProtoWithoutNullable(SchemaApi.FieldType p
return FieldType.logicalType(
(LogicalType)
SerializableUtils.deserializeFromByteArray(
protoFieldType.getLogicalType().getPayload().toByteArray(), "logicalType"));
logicalType.getPayload().toByteArray(), "logicalType"));
} else {
@Nullable FieldType argumentType = null;
@Nullable Object argumentValue = null;
if (protoFieldType.getLogicalType().hasArgumentType()) {
argumentType = fieldTypeFromProto(protoFieldType.getLogicalType().getArgumentType());
argumentValue =
fieldValueFromProto(argumentType, protoFieldType.getLogicalType().getArgument());
if (logicalType.hasArgumentType()) {
argumentType = fieldTypeFromProto(logicalType.getArgumentType());
argumentValue = fieldValueFromProto(argumentType, logicalType.getArgument());
}
return FieldType.logicalType(
new UnknownLogicalType(
urn,
protoFieldType.getLogicalType().getPayload().toByteArray(),
logicalType.getPayload().toByteArray(),
argumentType,
argumentValue,
fieldTypeFromProto(protoFieldType.getLogicalType().getRepresentation())));
fieldTypeFromProto(logicalType.getRepresentation())));
}
default:
throw new IllegalArgumentException(
Expand All @@ -393,6 +400,14 @@ public static Object rowFromProto(SchemaApi.Row row, FieldType fieldType) {

static SchemaApi.FieldValue fieldValueToProto(FieldType fieldType, Object value) {
FieldValue.Builder builder = FieldValue.newBuilder();
if (value == null) {
if (fieldType.getNullable()) {
return builder.build();
} else {
throw new RuntimeException("Null value found for field that doesn't support nulls.");
}
}

switch (fieldType.getTypeName()) {
case ARRAY:
return builder
Expand All @@ -411,26 +426,74 @@ static SchemaApi.FieldValue fieldValueToProto(FieldType fieldType, Object value)
.build();
case ROW:
return builder.setRowValue(rowToProto((Row) value)).build();
case DATETIME:
return builder
.setLogicalTypeValue(logicalTypeToProto(FieldType.INT64, fieldType, value))
.build();
case DECIMAL:
return builder
.setLogicalTypeValue(logicalTypeToProto(FieldType.BYTES, fieldType, value))
.build();
case LOGICAL_TYPE:
return builder
.setLogicalTypeValue(logicalTypeToProto(fieldType.getLogicalType(), value))
.build();
default:
return builder.setAtomicValue(primitiveRowFieldToProto(fieldType, value)).build();
}
}

/**
 * Returns whether the proto field value represents null.
 *
 * @param fieldType the schema type of the field being decoded
 * @param hasNonNullValue whether the proto oneof carried a concrete value
 * @return true when the proto holds no value and the field permits null
 * @throws RuntimeException if the proto holds no value but the field is non-nullable
 */
static boolean isNullFieldValueFromProto(FieldType fieldType, boolean hasNonNullValue) {
  if (hasNonNullValue) {
    return false;
  }
  if (!fieldType.getNullable()) {
    throw new RuntimeException("FieldTypeValue has no value but the field cannot be null.");
  }
  return true;
}

/**
 * Decodes a proto {@link SchemaApi.FieldValue} into a Java value according to {@code fieldType}.
 *
 * <p>Each branch first checks for a null proto value (throwing if the field is non-nullable),
 * then delegates to the type-specific decoder. DATETIME and DECIMAL are transported as logical
 * type values over INT64 and BYTES base types respectively.
 */
static Object fieldValueFromProto(FieldType fieldType, SchemaApi.FieldValue value) {
  switch (fieldType.getTypeName()) {
    case ARRAY:
      return isNullFieldValueFromProto(fieldType, value.hasArrayValue())
          ? null
          : arrayValueFromProto(fieldType.getCollectionElementType(), value.getArrayValue());
    case ITERABLE:
      return isNullFieldValueFromProto(fieldType, value.hasIterableValue())
          ? null
          : iterableValueFromProto(fieldType.getCollectionElementType(), value.getIterableValue());
    case MAP:
      return isNullFieldValueFromProto(fieldType, value.hasMapValue())
          ? null
          : mapFromProto(fieldType.getMapKeyType(), fieldType.getMapValueType(), value.getMapValue());
    case ROW:
      return isNullFieldValueFromProto(fieldType, value.hasRowValue())
          ? null
          : rowFromProto(value.getRowValue(), fieldType);
    case LOGICAL_TYPE:
      return isNullFieldValueFromProto(fieldType, value.hasLogicalTypeValue())
          ? null
          : logicalTypeFromProto(fieldType.getLogicalType(), value);
    case DATETIME:
      // DATETIME travels as a logical type with an INT64 base representation.
      return isNullFieldValueFromProto(fieldType, value.hasLogicalTypeValue())
          ? null
          : logicalTypeFromProto(FieldType.INT64, fieldType, value.getLogicalTypeValue());
    case DECIMAL:
      // DECIMAL travels as a logical type with a BYTES base representation.
      return isNullFieldValueFromProto(fieldType, value.hasLogicalTypeValue())
          ? null
          : logicalTypeFromProto(FieldType.BYTES, fieldType, value.getLogicalTypeValue());
    default:
      return isNullFieldValueFromProto(fieldType, value.hasAtomicValue())
          ? null
          : primitiveFromProto(fieldType, value.getAtomicValue());
  }
}
Expand Down Expand Up @@ -485,6 +548,74 @@ private static Object mapFromProto(
entry -> fieldValueFromProto(mapValueType, entry.getValue())));
}

/**
 * Converts a logical type value from proto using the default coder for {@code inputType}.
 *
 * <p>The proto carries the value encoded as its base type (INT64 or BYTES); this re-serializes
 * that base value and decodes it with the field's schema coder to recover the input-type value.
 *
 * @throws UnsupportedOperationException if the base type is neither INT64 nor BYTES
 */
private static Object logicalTypeFromProto(
    FieldType baseType, FieldType inputType, LogicalTypeValue value) {
  try {
    // Buffer in memory rather than piping: a PipedOutputStream whose connected
    // PipedInputStream has no reader thread blocks forever once the fixed-size pipe
    // buffer fills, which deadlocks on BYTES payloads larger than the buffer.
    ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    DataOutputStream stream = new DataOutputStream(buffer);
    switch (baseType.getTypeName()) {
      case INT64:
        stream.writeLong(value.getValue().getAtomicValue().getInt64());
        break;
      case BYTES:
        stream.write(value.getValue().getAtomicValue().getBytes().toByteArray());
        break;
      default:
        throw new UnsupportedOperationException(
            "Unsupported underlying type for parsing logical type via coder.");
    }
    stream.close();
    return SchemaCoderHelpers.coderForFieldType(inputType)
        .decode(new ByteArrayInputStream(buffer.toByteArray()));
  } catch (IOException e) {
    throw new RuntimeException(e);
  }
}

/**
 * Converts a logical type value to a proto using the default coder for {@code inputType}.
 *
 * <p>The value is encoded with the field's schema coder, then reinterpreted as the base type
 * (INT64 or BYTES) for transport inside a {@link LogicalTypeValue}.
 *
 * @throws UnsupportedOperationException if the base type is neither INT64 nor BYTES
 */
private static LogicalTypeValue logicalTypeToProto(
    FieldType baseType, FieldType inputType, Object value) {
  try {
    // Encode fully into an in-memory buffer. The previous piped-stream approach could
    // deadlock: PipedOutputStream.write blocks when the pipe buffer fills and nothing is
    // reading from the connected PipedInputStream on another thread.
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    SchemaCoderHelpers.coderForFieldType(inputType).encode(value, out);
    byte[] encoded = out.toByteArray();
    Object baseObject;
    switch (baseType.getTypeName()) {
      case INT64:
        baseObject = new DataInputStream(new ByteArrayInputStream(encoded)).readLong();
        break;
      case BYTES:
        baseObject = encoded;
        break;
      default:
        throw new UnsupportedOperationException(
            "Unsupported underlying type for producing LogicalType via coder.");
    }
    return LogicalTypeValue.newBuilder()
        .setValue(fieldValueToProto(baseType, baseObject))
        .build();
  } catch (IOException e) {
    throw new RuntimeException(e);
  }
}

/**
 * Wraps a logical type value in a {@link LogicalTypeValue} proto by first converting it to the
 * logical type's base representation.
 */
private static LogicalTypeValue logicalTypeToProto(LogicalType logicalType, Object value) {
  // Convert input-type value -> base-type value, then encode the base value as a FieldValue.
  Object baseValue = SchemaUtils.toLogicalBaseType(logicalType, value);
  SchemaApi.FieldValue encoded = fieldValueToProto(logicalType.getBaseType(), baseValue);
  return LogicalTypeValue.newBuilder().setValue(encoded).build();
}

/**
 * Decodes a {@link LogicalTypeValue} proto into an input-type value by decoding the embedded
 * base-type value and converting it through the logical type.
 */
private static Object logicalTypeFromProto(
    LogicalType logicalType, SchemaApi.FieldValue logicalValue) {
  SchemaApi.FieldValue baseProto = logicalValue.getLogicalTypeValue().getValue();
  Object baseValue = fieldValueFromProto(logicalType.getBaseType(), baseProto);
  return SchemaUtils.toLogicalInputType(logicalType, baseValue);
}

private static AtomicTypeValue primitiveRowFieldToProto(FieldType fieldType, Object value) {
switch (fieldType.getTypeName()) {
case BYTE:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.beam.sdk.schemas;

import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.Schema.LogicalType;

/** A set of utility functions for schemas. */
@SuppressWarnings({
Expand Down Expand Up @@ -101,4 +102,24 @@ static FieldType widenNullableTypes(FieldType fieldType1, FieldType fieldType2)
}
return result.withNullable(fieldType1.getNullable() || fieldType2.getNullable());
}

/**
 * Converts an input-type value to the logical type's base-type representation.
 *
 * <p>Useful for manipulating logical types generically, without statically knowing
 * {@code InputT} or {@code BaseT}.
 *
 * @param logicalType the logical type performing the conversion
 * @param inputType the value in the logical type's input representation
 * @return the same value in the logical type's base representation
 */
public static <InputT, BaseT> BaseT toLogicalBaseType(
    LogicalType<InputT, BaseT> logicalType, InputT inputType) {
  BaseT baseValue = logicalType.toBaseType(inputType);
  return baseValue;
}

/**
 * Converts a base-type value back to the logical type's input-type representation.
 *
 * <p>Useful for manipulating logical types generically, without statically knowing
 * {@code InputT} or {@code BaseT}.
 *
 * @param logicalType the logical type performing the conversion
 * @param baseType the value in the logical type's base representation
 * @return the same value in the logical type's input representation
 */
public static <BaseT, InputT> InputT toLogicalInputType(
    LogicalType<InputT, BaseT> logicalType, BaseT baseType) {
  InputT inputValue = logicalType.toInputType(baseType);
  return inputValue;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.Schema.TypeName;
import org.apache.beam.sdk.schemas.SchemaUtils;
import org.apache.beam.sdk.values.RowUtils.CapturingRowCases;
import org.apache.beam.sdk.values.RowUtils.FieldOverride;
import org.apache.beam.sdk.values.RowUtils.FieldOverrides;
Expand Down Expand Up @@ -460,7 +461,10 @@ public static boolean deepEquals(Object a, Object b, Schema.FieldType fieldType)
if (a == null || b == null) {
return a == b;
} else if (fieldType.getTypeName() == TypeName.LOGICAL_TYPE) {
return deepEquals(a, b, fieldType.getLogicalType().getBaseType());
return deepEquals(
SchemaUtils.toLogicalBaseType(fieldType.getLogicalType(), a),
SchemaUtils.toLogicalBaseType(fieldType.getLogicalType(), b),
fieldType.getLogicalType().getBaseType());
} else if (fieldType.getTypeName() == Schema.TypeName.BYTES) {
return Arrays.equals((byte[]) a, (byte[]) b);
} else if (fieldType.getTypeName() == TypeName.ARRAY) {
Expand Down Expand Up @@ -598,6 +602,9 @@ public String toString(boolean includeFieldNames) {
}

private String toString(Schema.FieldType fieldType, Object value, boolean includeFieldNames) {
if (value == null) {
return "<null>";
}
StringBuilder builder = new StringBuilder();
switch (fieldType.getTypeName()) {
case ARRAY:
Expand Down
Loading