Skip to content

Commit

Permalink
Better SQL query splitter
Browse files Browse the repository at this point in the history
  • Loading branch information
rodrigozhou committed Jan 10, 2023
1 parent 3e31880 commit 1df7ae8
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 7 deletions.
107 changes: 100 additions & 7 deletions common/persistence/query_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,24 @@
package persistence

import (
"bytes"
"fmt"
"io"
"os"
"strings"
"unicode"
)

const (
queryDelimiter = ";"
queryDelimiter = ';'
querySliceDefaultSize = 100

sqlLeftParenthesis = '('
sqlRightParenthesis = ')'
sqlBeginKeyword = "begin"
sqlEndKeyword = "end"
sqlLineComment = "--"
sqlSingleQuote = '\''
sqlDoubleQuote = '"'
)

// LoadAndSplitQuery loads and split cql / sql query into one statement per string
Expand All @@ -57,22 +66,106 @@ func LoadAndSplitQuery(
func LoadAndSplitQueryFromReaders(
readers []io.Reader,
) ([]string, error) {

result := make([]string, 0, querySliceDefaultSize)

for _, r := range readers {
content, err := io.ReadAll(r)
if err != nil {
return nil, fmt.Errorf("error reading contents: %w", err)
}
for _, stmt := range strings.Split(string(content), queryDelimiter) {
stmt = strings.TrimSpace(stmt)
n := len(content)
contentStr := string(bytes.ToLower(content))
for i, j := 0, 0; i < n; i = j {
// stack to keep track of open parenthesis/blocks
var st []byte

stmtLoop:
for ; j < n; j++ {
switch contentStr[j] {
case queryDelimiter:
if len(st) == 0 {
j += 1
break stmtLoop
}

case sqlLeftParenthesis:
st = append(st, sqlLeftParenthesis)

case sqlRightParenthesis:
if len(st) == 0 || st[len(st)-1] != sqlLeftParenthesis {
return nil, fmt.Errorf("error reading contents: unmatched right parenthesis")
}
st = st[:len(st)-1]

case sqlBeginKeyword[0]:
if matchWord(contentStr, sqlBeginKeyword, j) {
st = append(st, sqlBeginKeyword[0])
j += len(sqlBeginKeyword) - 1
}

case sqlEndKeyword[0]:
if matchWord(contentStr, sqlEndKeyword, j) {
if len(st) == 0 || st[len(st)-1] != sqlBeginKeyword[0] {
return nil, fmt.Errorf("error reading contents: unmatched `END` keyword")
}
st = st[:len(st)-1]
j += len(sqlEndKeyword) - 1
}

case sqlSingleQuote, sqlDoubleQuote:
quote := contentStr[j]
j += 1
for j < n && contentStr[j] != quote {
j += 1
}
if j == n {
return nil, fmt.Errorf("error reading contents: unmatched quotes")
}

case sqlLineComment[0]:
if j+len(sqlLineComment) <= n && contentStr[j:j+len(sqlLineComment)] == sqlLineComment {
for j < n && contentStr[j] != '\n' {
j += 1
}
}

default:
// no-op: generic character
}
}

if len(st) > 0 {
switch st[len(st)-1] {
case sqlLeftParenthesis:
return nil, fmt.Errorf("error reading contents: unmatched left parenthesis")
case sqlBeginKeyword[0]:
return nil, fmt.Errorf("error reading contents: unmatched `BEGIN` keyword")
default:
// should never enter here
return nil, fmt.Errorf("error reading contents: unmatched `%c`", st[len(st)-1])
}
}

stmt := string(bytes.TrimSpace(content[i:j]))
if stmt == "" {
continue
}
result = append(result, stmt)
}

}
return result, nil
}

// matchWord is a simple test to check if it matches the whole word:
// it checks if the adjacent charactes are not letters if they exist.
func matchWord(s, word string, pos int) bool {
if pos+len(word) > len(s) || s[pos:pos+len(word)] != word {
return false
}
if pos > 0 && unicode.IsLetter(rune(s[pos-1])) {
return false
}
if pos+len(word) < len(s) && unicode.IsLetter(rune(s[pos+len(word)])) {
return false
}
return true
}
129 changes: 129 additions & 0 deletions common/persistence/query_util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package persistence

import (
"bytes"
"io"
"testing"

"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"

"go.temporal.io/server/common/log"
)

type (
queryUtilSuite struct {
suite.Suite
// override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test,
// not merely log an error
*require.Assertions
logger log.Logger
}
)

func TestQueryUtilSuite(t *testing.T) {
s := new(queryUtilSuite)
suite.Run(t, s)
}

func (s *queryUtilSuite) SetupTest() {
s.logger = log.NewTestLogger()
// Have to define our overridden assertions in the test setup. If we did it earlier, s.T() will return nil
s.Assertions = require.New(s.T())
}

func (s *queryUtilSuite) TestLoadAndSplitQueryFromReaders() {
input := `
CREATE TABLE test (
id BIGINT not null,
col1 BIGINT, -- comment with unmatched parenthesis )
col2 VARCHAR(255),
PRIMARY KEY (id)
);
CREATE INDEX test_idx ON test (col1);
CREATE TRIGGER test_ai AFTER INSERT ON test
BEGIN
SELECT *, 'string with unmatched chars ")' FROM test;
END;
`
statements, err := LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)})
s.NoError(err)
s.Equal(3, len(statements))
s.Equal(
`CREATE TABLE test (
id BIGINT not null,
col1 BIGINT, -- comment with unmatched parenthesis )
col2 VARCHAR(255),
PRIMARY KEY (id)
);`,
statements[0],
)
s.Equal(`CREATE INDEX test_idx ON test (col1);`, statements[1])
s.Equal(
`CREATE TRIGGER test_ai AFTER INSERT ON test
BEGIN
SELECT *, 'string with unmatched chars ")' FROM test;
END;`,
statements[2],
)

input = "CREATE TABLE test (;"
statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)})
s.Error(err, "error reading contents: unmatched left parenthesis")
s.Nil(statements)

input = "CREATE TABLE test ());"
statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)})
s.Error(err, "error reading contents: unmatched right parenthesis")
s.Nil(statements)

input = "begin"
statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)})
s.Error(err, "error reading contents: unmatched `BEGIN` keyword")
s.Nil(statements)

input = "end"
statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)})
s.Error(err, "error reading contents: unmatched `END` keyword")
s.Nil(statements)

input = "select ' from test;"
statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)})
s.Error(err, "error reading contents: unmatched quotes")
s.Nil(statements)
}

func (s *queryUtilSuite) TestMatchWord() {
s.True(matchWord("BEGIN", "BEGIN", 0))
s.True(matchWord(" BEGIN ", "BEGIN", 1))
s.True(matchWord(")BEGIN;", "BEGIN", 1))
s.False(matchWord("BEGIN", "BEGIN", 1))
s.False(matchWord("sBEGIN", "BEGIN", 1))
s.False(matchWord("BEGINs", "BEGIN", 1))
}

0 comments on commit 1df7ae8

Please sign in to comment.