Skip to content

Commit

Permalink
DRILL-7928: Add fourth parameter for split_part udf
Browse files Browse the repository at this point in the history
  • Loading branch information
Leon-WTF authored and laurentgo committed Jun 1, 2021
1 parent 37abb0a commit ad3f344
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,10 @@ public void eval() {

}


/**
* Return the string part at index after splitting the input string using the
* specified delimiter. The index must be a positive integer.
*/
@FunctionTemplate(name = "split_part", scope = FunctionScope.SIMPLE, nulls = NullHandling.NULL_IF_NULL,
outputWidthCalculatorType = OutputWidthCalculatorType.CUSTOM_FIXED_WIDTH_DEFAULT)
public static class SplitPart implements DrillSimpleFunc {
Expand All @@ -405,10 +408,6 @@ public static class SplitPart implements DrillSimpleFunc {

@Override
public void setup() {
if (index.value < 1) {
throw org.apache.drill.common.exceptions.UserException.functionError()
.message("Index in split_part must be positive, value provided was " + index.value).build();
}
String split = org.apache.drill.exec.expr.fn.impl.StringFunctionHelpers.
toStringFromUTF8(delimiter.start, delimiter.end, delimiter.buffer);
splitter = com.google.common.base.Splitter.on(split);
Expand All @@ -417,8 +416,13 @@ public void setup() {

@Override
public void eval() {
String inputString =
org.apache.drill.exec.expr.fn.impl.StringFunctionHelpers.toStringFromUTF8(in.start, in.end, in.buffer);
if (index.value < 1) {
throw org.apache.drill.common.exceptions.UserException.functionError()
.message("Index in split_part must be positive, value provided was "
+ index.value).build();
}
String inputString = org.apache.drill.exec.expr.fn.impl.
StringFunctionHelpers.getStringFromVarCharHolder(in);
int arrayIndex = index.value - 1;
String result =
(String) com.google.common.collect.Iterables.get(splitter.split(inputString), arrayIndex, "");
Expand All @@ -432,6 +436,74 @@ public void eval() {

}

/**
* Return the string part from start to end after splitting the input string
* using the specified delimiter. The start must be a positive integer. The
* end is included and must be greater than or equal to the start index.
*/
@FunctionTemplate(name = "split_part", scope = FunctionScope.SIMPLE, nulls =
NullHandling.NULL_IF_NULL, outputWidthCalculatorType =
OutputWidthCalculatorType.CUSTOM_FIXED_WIDTH_DEFAULT)
public static class SplitPartStartEnd implements DrillSimpleFunc {
@Param
VarCharHolder in;
@Param
VarCharHolder delimiter;
@Param
IntHolder start;
@Param
IntHolder end;

@Workspace
com.google.common.base.Splitter splitter;

@Workspace
com.google.common.base.Joiner joiner;

@Inject
DrillBuf buffer;

@Output
VarCharHolder out;

@Override
public void setup() {
String split = org.apache.drill.exec.expr.fn.impl.StringFunctionHelpers.
toStringFromUTF8(delimiter.start, delimiter.end, delimiter.buffer);
splitter = com.google.common.base.Splitter.on(split);
joiner = com.google.common.base.Joiner.on(split);
}

@Override
public void eval() {
if (start.value < 1) {
throw org.apache.drill.common.exceptions.UserException.functionError()
.message("Start in split_part must be positive, value provided was "
+ start.value).build();
}
if (end.value < start.value) {
throw org.apache.drill.common.exceptions.UserException.functionError()
.message("End in split_part must be greater than or equal to start, " +
"value provided was start:" + start.value + ",end:" + end.value).build();
}
String inputString = org.apache.drill.exec.expr.fn.impl.
StringFunctionHelpers.getStringFromVarCharHolder(in);
int arrayIndex = start.value - 1;
java.util.Iterator<String> iterator = com.google.common.collect.Iterables
.limit(com.google.common.collect.Iterables.skip(splitter
.split(inputString), arrayIndex),end.value - start.value + 1)
.iterator();
byte[] strBytes = joiner.join(iterator).getBytes(
com.google.common.base.Charsets.UTF_8);

out.buffer = buffer = buffer.reallocIfNeeded(strBytes.length);
out.start = 0;
out.end = strBytes.length;
out.buffer.setBytes(0, strBytes);
}

}

// same as function "position(substr, str) ", except the reverse order of argument.
@FunctionTemplate(name = "strpos", scope = FunctionScope.SIMPLE, nulls = NullHandling.NULL_IF_NULL)
public static class Strpos implements DrillSimpleFunc {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,24 +70,6 @@ public void testSplitPart() throws Exception {
.baselineValues("rty")
.go();

// invalid index
boolean expectedErrorEncountered;
try {
testBuilder()
.sqlQuery("select split_part('abc~@~def~@~ghi', '~@~', 0) res1 from (values(1))")
.ordered()
.baselineColumns("res1")
.baselineValues("abc")
.go();
expectedErrorEncountered = false;
} catch (Exception ex) {
assertTrue(ex.getMessage().contains("Index in split_part must be positive, value provided was 0"));
expectedErrorEncountered = true;
}
if (!expectedErrorEncountered) {
throw new RuntimeException("Missing expected error on invalid index for split_part function");
}

// with a multi-byte splitter
testBuilder()
.sqlQuery("select split_part('abc\\u1111drill\\u1111ghi', '\\u1111', 2) res1 from (values(1))")
Expand All @@ -113,6 +95,102 @@ public void testSplitPart() throws Exception {
.go();
}

@Test
public void testSplitPartStartEnd() throws Exception {
testBuilder()
.sqlQuery("select split_part(a, '~@~', 1, 2) res1 from (" +
"values('abc~@~def~@~ghi'), ('qwe~@~rty~@~uio')) as t(a)")
.ordered()
.baselineColumns("res1")
.baselineValues("abc~@~def")
.baselineValues("qwe~@~rty")
.go();

testBuilder()
.sqlQuery("select split_part(a, '~@~', 2, 3) res1 from (" +
"values('abc~@~def~@~ghi'), ('qwe~@~rty~@~uio')) as t(a)")
.ordered()
.baselineColumns("res1")
.baselineValues("def~@~ghi")
.baselineValues("rty~@~uio")
.go();

// with a multi-byte splitter
testBuilder()
.sqlQuery("select split_part('abc\\u1111drill\\u1111ghi', '\\u1111', 2, 2) " +
"res1 from (values(1))")
.ordered()
.baselineColumns("res1")
.baselineValues("drill")
.go();

// start index going beyond the last available index, returns empty string
testBuilder()
.sqlQuery("select split_part('a,b,c', ',', 4, 5) res1 from (values(1))")
.ordered()
.baselineColumns("res1")
.baselineValues("")
.go();

// end index going beyond the last available index, returns remaining string
testBuilder()
.sqlQuery("select split_part('a,b,c', ',', 1, 10) res1 from (values(1))")
.ordered()
.baselineColumns("res1")
.baselineValues("a,b,c")
.go();

// if the delimiter does not appear in the string, 1 returns the whole string
testBuilder()
.sqlQuery("select split_part('a,b,c', ' ', 1, 2) res1 from (values(1))")
.ordered()
.baselineColumns("res1")
.baselineValues("a,b,c")
.go();
}

@Test
public void testInvalidSplitPartParameters() {
boolean expectedErrorEncountered;
try {
testBuilder()
.sqlQuery("select split_part('abc~@~def~@~ghi', '~@~', 0) res1 from " +
"(values(1))")
.ordered()
.baselineColumns("res1")
.baselineValues("abc")
.go();
expectedErrorEncountered = false;
} catch (Exception ex) {
assertTrue(ex.getMessage().contains("Index in split_part must be positive, " +
"value provided was 0"));
expectedErrorEncountered = true;
}
if (!expectedErrorEncountered) {
throw new RuntimeException("Missing expected error on invalid index for " +
"split_part function");
}

try {
testBuilder()
.sqlQuery("select split_part('abc~@~def~@~ghi', '~@~', 2, 1) res1 from " +
"(values(1))")
.ordered()
.baselineColumns("res1")
.baselineValues("abc")
.go();
expectedErrorEncountered = false;
} catch (Exception ex) {
assertTrue(ex.getMessage().contains("End in split_part must be greater " +
"than or equal to start"));
expectedErrorEncountered = true;
}
if (!expectedErrorEncountered) {
throw new RuntimeException("Missing expected error on invalid index for " +
"split_part function");
}
}

@Test
public void testRegexpMatches() throws Exception {
testBuilder()
Expand Down

0 comments on commit ad3f344

Please sign in to comment.