Add spark submodule to convert SparkSql's LogicalPlan to Substrait Rel. #90

Closed
wants to merge 4 commits into from
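For orientation, a minimal usage sketch of what this PR enables, mirroring the SparkSqlPlanTest added below: take the optimized LogicalPlan of a Spark SQL query and hand it to the new converter. The lineitem table assumes the TPC-H schema that the test sets up; everything else uses only the code added in this PR.

import org.apache.spark.sql.SparkSession
import io.substrait.spark.SparkLogicalPlanConverter

val spark = SparkSession.builder().master("local[2]").getOrCreate()
// The optimized plan is the same plan the test harness below converts.
val plan = spark.sql("select * from lineitem").queryExecution.optimizedPlan
val rel: io.substrait.relation.Rel = SparkLogicalPlanConverter.convert(plan)
println(rel)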
2 changes: 1 addition & 1 deletion settings.gradle.kts
@@ -1,6 +1,6 @@
rootProject.name = "substrait"

include("bom", "core", "isthmus")
include("bom", "core", "isthmus", "spark")

pluginManagement {
plugins {
Empty file added spark/README.md
Empty file.
72 changes: 72 additions & 0 deletions spark/build.gradle.kts
@@ -0,0 +1,72 @@
plugins {
`maven-publish`
id("java")
id("scala")
id("idea")
// id("com.palantir.graal") version "0.10.0"
id("com.diffplug.spotless") version "6.5.1"
}

publishing { publications { create<MavenPublication>("maven") { from(components["java"]) } } }

java { toolchain { languageVersion.set(JavaLanguageVersion.of(17)) } }

dependencies {
implementation(project(":core")) {
exclude("org.slf4j", "slf4j-jdk14")
exclude("org.antlr", "antlr4")
}
// spark
implementation("org.scala-lang:scala-library:2.12.16")
testImplementation("org.scalatest:scalatest_2.12:3.3.0-SNAP3")

implementation("org.apache.spark:spark-sql_2.12:3.3.0")
testImplementation("org.apache.spark:spark-hive_2.12:3.3.0")

// testImplementation("org.apache.spark:spark-sql_2.12:3.3.0:tests")
Member: Remove?

// testImplementation("org.apache.spark:spark-core_2.12:3.3.0:tests")
// testImplementation("org.apache.spark:spark-catalyst_2.12:3.3.0:tests")

// iceberg-spark
testImplementation("org.testcontainers:testcontainers:1.17.3")
testImplementation("org.testcontainers:junit-jupiter:1.17.3")
testImplementation("org.apache.iceberg:iceberg-spark-runtime-3.3_2.12:0.14.0")

testImplementation("org.junit.jupiter:junit-jupiter:5.7.0")
implementation("org.reflections:reflections:0.9.12")
implementation("com.google.guava:guava:29.0-jre")
// implementation("org.graalvm.sdk:graal-sdk:22.0.0.2")
// implementation("info.picocli:picocli:4.6.1")
implementation("com.google.protobuf:protobuf-java-util:3.17.3") {
exclude("com.google.guava", "guava")
.because("Brings in Guava for Android, which we don't want (and breaks multimaps).")
}
implementation("com.google.code.findbugs:jsr305:3.0.2")
implementation("com.github.ben-manes.caffeine:caffeine:3.0.4")
implementation("org.immutables:value-annotations:2.8.8")
}

tasks {
test {
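// Spark 3.3 reflectively accesses JDK-internal packages (sun.nio.*, java.lang, java.util, ...)
// that are closed by default on Java 17, so the test JVM has to open them explicitly.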
jvmArgs(
"--add-opens",
"java.base/sun.nio.ch=ALL-UNNAMED",
"--add-opens",
"java.base/sun.nio.cs=ALL-UNNAMED",
"--add-opens",
"java.base/java.lang=ALL-UNNAMED",
"--add-opens",
"java.base/java.io=ALL-UNNAMED",
"--add-opens",
"java.base/java.net=ALL-UNNAMED",
"--add-opens",
"java.base/java.nio=ALL-UNNAMED",
"--add-opens",
"java.base/java.util=ALL-UNNAMED",
"--add-opens",
"java.base/sun.security.action=ALL-UNNAMED",
"--add-opens",
"java.base/sun.util.calendar=ALL-UNNAMED"
)
}
}
14 changes: 14 additions & 0 deletions spark/src/main/scala/io/substrait/spark/SparkExpressionConverter.scala
@@ -0,0 +1,14 @@
package io.substrait.spark

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

object SparkExpressionConverter {

def convert(expr: Expression,
inputPlan: LogicalPlan,
inputRecordType: io.substrait.`type`.Type.Struct): io.substrait.expression.Expression = expr match {
case _ =>
throw new UnsupportedOperationException("Unable to convert the expr to a substrait Expression: " + expr)
}
}
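SparkExpressionConverter is deliberately a stub at this point: every expression falls through to the exception. As a rough sketch of the direction, a first case could handle simple Catalyst literals. The ExpressionCreator factory calls below are an assumption about substrait-java's io.substrait.expression API and should be verified against the core module before use.

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.types.{BooleanType, IntegerType}
import io.substrait.expression.ExpressionCreator

// Hypothetical first cases for the match above; anything else would still fall
// through to the UnsupportedOperationException.
def convertLiteral(expr: Expression): io.substrait.expression.Expression = expr match {
  case Literal(v: Int, IntegerType) => ExpressionCreator.i32(false, v)      // non-null literal => non-nullable i32
  case Literal(v: Boolean, BooleanType) => ExpressionCreator.bool(false, v) // non-null literal => non-nullable bool
  case _ =>
    throw new UnsupportedOperationException("Unable to convert the expr: " + expr)
}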
80 changes: 80 additions & 0 deletions spark/src/main/scala/io/substrait/spark/SparkLogicalPlanConverter.scala
@@ -0,0 +1,80 @@
package io.substrait.spark

import io.substrait.relation.NamedScan
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation}
import org.apache.spark.sql.types.StructType

import scala.collection.JavaConverters._

object SparkLogicalPlanConverter {

def convert(plan: LogicalPlan): io.substrait.relation.Rel = plan match {
case _: Project =>
convertProject(plan)
case _: LogicalRelation | _: DataSourceV2ScanRelation | _: HiveTableRelation =>
convertReadOperator(plan)
case _ =>
throw new UnsupportedOperationException("Unable to convert the plan to a substrait rel: " + plan)

}

/**
* Project Operator: https://substrait.io/relations/logical_relations/#project-operation
*
* @param plan the Spark Project node to convert
* @return the equivalent Substrait Project relation
*/
def convertProject(plan: LogicalPlan): io.substrait.relation.Project = plan match {
case project: Project =>
val childRel = SparkLogicalPlanConverter.convert(project.child);
val exprList = project.projectList.map(expr => SparkExpressionConverter.convert(expr, project.child, childRel.getRecordType)).asJava
val projectRel = io.substrait.relation.Project.builder
.expressions(exprList)
.input(childRel)
.build
projectRel
}

def buildNamedScan(schema: StructType, tableNames: List[String]): NamedScan = {
val namedStruct = SparkTypeConverter.toNamedStruct(schema)
val namedScan = NamedScan.builder.initialSchema(namedStruct).addAllNames(tableNames.asJava).build
namedScan
}

/**
* Read Operator: https://substrait.io/relations/logical_relations/#read-operator
*
* @param plan the Spark scan/relation node to convert
* @return the equivalent Substrait read relation
*/
def convertReadOperator(plan: LogicalPlan): io.substrait.relation.AbstractReadRel = {
var schema: StructType = null
var tableNames: List[String] = null;
plan match {
case logicalRelation: LogicalRelation =>
schema = logicalRelation.schema
tableNames = logicalRelation.catalogTable.get.identifier.unquotedString.split("\\.").toList
buildNamedScan(schema, tableNames)
case dataSourceV2ScanRelation: DataSourceV2ScanRelation =>
schema = dataSourceV2ScanRelation.schema
tableNames = dataSourceV2ScanRelation.relation.identifier.get.toString.split("\\.").toList
buildNamedScan(schema, tableNames)
case dataSourceV2Relation: DataSourceV2Relation =>
schema = dataSourceV2Relation.schema
tableNames = dataSourceV2Relation.identifier.get.toString.split("\\.").toList
buildNamedScan(schema, tableNames)
case hiveTableRelation: HiveTableRelation =>
schema = hiveTableRelation.schema
tableNames = hiveTableRelation.tableMeta.identifier.unquotedString.split("\\.").toList
buildNamedScan(schema, tableNames)
//TODO: LocalRelation,Range=>Virtual Table,LogicalRelation(HadoopFsRelation)=>LocalFiles

case _ =>
throw new UnsupportedOperationException("Unable to convert the plan to a substrait AbstractReadRel: " + plan)
}
Comment on lines +54 to +77
Contributor
Suggested change (original lines first, then the proposed replacement):
var schema: StructType = null
var tableNames: List[String] = null;
plan match {
case logicalRelation: LogicalRelation =>
schema = logicalRelation.schema
tableNames = logicalRelation.catalogTable.get.identifier.unquotedString.split("\\.").toList
buildNamedScan(schema, tableNames)
case dataSourceV2ScanRelation: DataSourceV2ScanRelation =>
schema = dataSourceV2ScanRelation.schema
tableNames = dataSourceV2ScanRelation.relation.identifier.get.toString.split("\\.").toList
buildNamedScan(schema, tableNames)
case dataSourceV2Relation: DataSourceV2Relation =>
schema = dataSourceV2Relation.schema
tableNames = dataSourceV2Relation.identifier.get.toString.split("\\.").toList
buildNamedScan(schema, tableNames)
case hiveTableRelation: HiveTableRelation =>
schema = hiveTableRelation.schema
tableNames = hiveTableRelation.tableMeta.identifier.unquotedString.split("\\.").toList
buildNamedScan(schema, tableNames)
//TODO: LocalRelation,Range=>Virtual Table,LogicalRelation(HadoopFsRelation)=>LocalFiles
case _ =>
throw new UnsupportedOperationException("Unable to convert the plan to a substrait AbstractReadRel: " + plan)
}

Proposed replacement:

val (schema, tableNames) =
plan match {
case logicalRelation: LogicalRelation =>
(logicalRelation.schema, logicalRelation.catalogTable.get.identifier.unquotedString.split("\\.").toList)
case dataSourceV2ScanRelation: DataSourceV2ScanRelation =>
(dataSourceV2ScanRelation.schema, dataSourceV2ScanRelation.relation.identifier.get.toString.split("\\.").toList)
case dataSourceV2Relation: DataSourceV2Relation =>
(dataSourceV2Relation.schema, dataSourceV2Relation.identifier.get.toString.split("\\.").toList)
case hiveTableRelation: HiveTableRelation =>
(hiveTableRelation.schema, hiveTableRelation.tableMeta.identifier.unquotedString.split("\\.").toList)
//TODO: LocalRelation,Range=>Virtual Table,LogicalRelation(HadoopFsRelation)=>LocalFiles
case _ =>
throw new UnsupportedOperationException("Unable to convert the plan to a substrait AbstractReadRel: " + plan)
}
buildNamedScan(schema, tableNames)


}
}
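Filter would be a natural next operator and follows the same shape as convertProject above: convert the child, convert the condition against the child's record type, and wrap both in a Substrait Filter. The io.substrait.relation.Filter.builder calls (input/condition) are an assumption about the core module's API; check them before relying on this sketch.

import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}

def convertFilter(plan: LogicalPlan): io.substrait.relation.Filter = plan match {
  case filter: Filter =>
    val childRel = SparkLogicalPlanConverter.convert(filter.child)
    // Resolve the predicate against the child's schema, as convertProject does for its expressions.
    val condition =
      SparkExpressionConverter.convert(filter.condition, filter.child, childRel.getRecordType)
    io.substrait.relation.Filter.builder
      .input(childRel)
      .condition(condition)
      .build
}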
62 changes: 62 additions & 0 deletions spark/src/main/scala/io/substrait/spark/SparkTypeConverter.scala
@@ -0,0 +1,62 @@
package io.substrait.spark

import io.substrait.`type`.{NamedStruct, Type}
import org.apache.spark.sql.types._

object SparkTypeConverter {

def toNamedStruct(schema: StructType): io.substrait.`type`.NamedStruct = {
val creator = Type.withNullability(true)
val names = new java.util.ArrayList[String]
val children = new java.util.ArrayList[Type]
schema.fields.foreach(field => {
names.add(field.name)
children.add(convert(field.dataType, field.nullable))
})
val struct = creator.struct(children)
NamedStruct.of(names, struct)
Comment on lines +10 to +18
Contributor
Suggested change (original lines first, then the proposed replacement):
val creator = Type.withNullability(true)
val names = new java.util.ArrayList[String]
val children = new java.util.ArrayList[Type]
schema.fields.foreach(field => {
names.add(field.name)
children.add(convert(field.dataType, field.nullable))
})
val struct = creator.struct(children)
NamedStruct.of(names, struct)

Proposed replacement:

import scala.jdk.CollectionConverters._
val names = schema.fields.map(_.name).asJava
val struct = creator.struct(
schema.fields.map(field => convert(field.dataType, field.nullable)).asJava
)
NamedStruct.of(names, struct)

}

def convert(dataType: DataType, nullable: Boolean): io.substrait.`type`.Type = {
val creator = Type.withNullability(nullable)
// Spark SQL data types: https://spark.apache.org/docs/latest/sql-ref-datatypes.html
dataType match {
case ByteType => creator.I8
case ShortType => creator.I16
case IntegerType => creator.I32
case LongType => creator.I64
case FloatType => creator.FP32
case DoubleType => creator.FP64
case decimalType: DecimalType =>
if (decimalType.precision > 38) {
throw new UnsupportedOperationException("unsupported decimal precision " + decimalType.precision)
}
creator.decimal(decimalType.precision, decimalType.scale);
case StringType => creator.STRING
case BinaryType => creator.BINARY
case BooleanType => creator.BOOLEAN
case TimestampType => creator.TIMESTAMP
case DateType => creator.DATE
case YearMonthIntervalType.DEFAULT => creator.INTERVAL_YEAR
case DayTimeIntervalType.DEFAULT => creator.INTERVAL_DAY
case ArrayType(elementType, containsNull) =>
creator.list(convert(elementType, containsNull))
case MapType(keyType, valueType, valueContainsNull) =>
creator.map(convert(keyType, nullable = false), convert(valueType, valueContainsNull))
case StructType(fields) =>
// TODO: the nested StructType's field names are lost here; do we need to preserve them?
//val names = new java.util.ArrayList[String]
val children = new java.util.ArrayList[Type]
fields.foreach(field => {
Contributor: map (presumably suggesting to build the children with a map over the fields instead of the mutable ArrayList, as in the suggestion above)

//names.add(field.name)
children.add(convert(field.dataType, field.nullable))
})
val struct = creator.struct(children)
struct
case _ =>
throw new UnsupportedOperationException("Unable to convert the type " + field.toString)
}
}

}
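A small usage sketch for the type converter: a two-column Spark schema becomes a NamedStruct whose names and child types line up with the fields, using only the conversions defined above.

import org.apache.spark.sql.types._
import io.substrait.spark.SparkTypeConverter

val schema = StructType(Seq(
  StructField("l_orderkey", LongType, nullable = false),
  StructField("l_comment", StringType, nullable = true)
))
// Yields the names [l_orderkey, l_comment] and a struct of (non-nullable i64, nullable string).
val namedStruct = SparkTypeConverter.toNamedStruct(schema)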
28 changes: 28 additions & 0 deletions spark/src/test/java/io/substrait/spark/BaseSparkSqlPlanTest.java
@@ -0,0 +1,28 @@
package io.substrait.spark;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.junit.jupiter.api.AfterAll;

public class BaseSparkSqlPlanTest {

protected static SparkSession spark;

@AfterAll
public static void afterAll() {
if (spark != null) {
spark.stop();
}
}

protected static Dataset<Row> sql(String sql) {
System.out.println(sql);
return spark.sql(sql);
}

protected static LogicalPlan plan(String sql) {
return sql(sql).queryExecution().optimizedPlan();
}
}
56 changes: 56 additions & 0 deletions spark/src/test/java/io/substrait/spark/SparkSqlPlanTest.java
@@ -0,0 +1,56 @@
package io.substrait.spark;

import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

public class SparkSqlPlanTest extends BaseSparkSqlPlanTest {

public static void prepareSparkTables(SparkSession spark) throws IOException {
File localWareHouseDir = new File("spark-warehouse");
if (localWareHouseDir.exists()) {
FileUtils.deleteDirectory(localWareHouseDir);
}
FileUtils.forceMkdir(localWareHouseDir);
spark.sql("DROP DATABASE IF EXISTS tpch CASCADE");
spark.sql("CREATE DATABASE IF NOT EXISTS tpch");
spark.sql("use tpch");
String tpchCreateTableString =
FileUtils.readFileToString(
new File("src/test/resources/tpch_schema.sql"), StandardCharsets.UTF_8);
Arrays.stream(tpchCreateTableString.split(";"))
.filter(StringUtils::isNotBlank)
.toList()
.forEach(spark::sql);
spark.sql("show tables").show();
}

@BeforeAll
public static void beforeAll() {
spark =
SparkSession.builder()
.master("local[2]")
.config("spark.sql.legacy.createHiveTableByDefault", "false")
.getOrCreate();
try {
prepareSparkTables(spark);
} catch (IOException e) {
Assertions.fail(e);
}
}

@Test
public void testReadRel() {
LogicalPlan plan = plan("select * from lineitem");
System.out.println(plan.treeString());
System.out.println(SparkLogicalPlanConverter.convert(plan));
}
}
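The build file declares scalatest_2.12 even though the tests above are JUnit-based. If Scala-side tests are wanted, a minimal suite could look like the sketch below; the class name and location are hypothetical, and the assertion leans on the value-based equality of the generated Substrait type classes.

import org.apache.spark.sql.types.IntegerType
import org.scalatest.funsuite.AnyFunSuite

import io.substrait.`type`.Type
import io.substrait.spark.SparkTypeConverter

class SparkTypeConverterSuite extends AnyFunSuite {
  test("IntegerType converts to a nullable i32") {
    val converted = SparkTypeConverter.convert(IntegerType, nullable = true)
    assert(converted == Type.withNullability(true).I32)
  }
}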