From ad885d7b08b527940c5a82a2cf088ee4ca1c733a Mon Sep 17 00:00:00 2001 From: rodrigozhou Date: Mon, 23 Jan 2023 17:23:52 -0800 Subject: [PATCH] Better SQL query splitter --- common/persistence/query_util.go | 122 ++++++++++++++++++++++-- common/persistence/query_util_test.go | 131 ++++++++++++++++++++++++++ tools/common/schema/util.go | 45 +-------- 3 files changed, 247 insertions(+), 51 deletions(-) create mode 100644 common/persistence/query_util_test.go diff --git a/common/persistence/query_util.go b/common/persistence/query_util.go index da7a58771b0..2ccd5e109ba 100644 --- a/common/persistence/query_util.go +++ b/common/persistence/query_util.go @@ -25,18 +25,29 @@ package persistence import ( + "bytes" "fmt" "io" "os" "strings" + "unicode" ) const ( - queryDelimiter = ";" + queryDelimiter = ';' querySliceDefaultSize = 100 + + sqlLeftParenthesis = '(' + sqlRightParenthesis = ')' + sqlBeginKeyword = "begin" + sqlEndKeyword = "end" + sqlLineComment = "--" + sqlSingleQuote = '\'' + sqlDoubleQuote = '"' ) -// LoadAndSplitQuery loads and split cql / sql query into one statement per string +// LoadAndSplitQuery loads and split cql / sql query into one statement per string. +// Comments are removed from the query. func LoadAndSplitQuery( filePaths []string, ) ([]string, error) { @@ -53,26 +64,121 @@ func LoadAndSplitQuery( return LoadAndSplitQueryFromReaders(files) } -// LoadAndSplitQueryFromReaders loads and split cql / sql query into one statement per string +// LoadAndSplitQueryFromReaders loads and split cql / sql query into one statement per string. +// Comments are removed from the query. func LoadAndSplitQueryFromReaders( readers []io.Reader, ) ([]string, error) { - result := make([]string, 0, querySliceDefaultSize) - for _, r := range readers { content, err := io.ReadAll(r) if err != nil { return nil, fmt.Errorf("error reading contents: %w", err) } - for _, stmt := range strings.Split(string(content), queryDelimiter) { - stmt = strings.TrimSpace(stmt) + n := len(content) + contentStr := string(bytes.ToLower(content)) + for i, j := 0, 0; i < n; i = j { + // stack to keep track of open parenthesis/blocks + var st []byte + var stmtBuilder strings.Builder + + stmtLoop: + for ; j < n; j++ { + switch contentStr[j] { + case queryDelimiter: + if len(st) == 0 { + j += 1 + break stmtLoop + } + + case sqlLeftParenthesis: + st = append(st, sqlLeftParenthesis) + + case sqlRightParenthesis: + if len(st) == 0 || st[len(st)-1] != sqlLeftParenthesis { + return nil, fmt.Errorf("error reading contents: unmatched right parenthesis") + } + st = st[:len(st)-1] + + case sqlBeginKeyword[0]: + if matchWord(contentStr, sqlBeginKeyword, j) { + st = append(st, sqlBeginKeyword[0]) + j += len(sqlBeginKeyword) - 1 + } + + case sqlEndKeyword[0]: + if matchWord(contentStr, sqlEndKeyword, j) { + if len(st) == 0 || st[len(st)-1] != sqlBeginKeyword[0] { + return nil, fmt.Errorf("error reading contents: unmatched `END` keyword") + } + st = st[:len(st)-1] + j += len(sqlEndKeyword) - 1 + } + + case sqlSingleQuote, sqlDoubleQuote: + quote := contentStr[j] + j += 1 + for j < n && contentStr[j] != quote { + j += 1 + } + if j == n { + return nil, fmt.Errorf("error reading contents: unmatched quotes") + } + + case sqlLineComment[0]: + if j+len(sqlLineComment) <= n && contentStr[j:j+len(sqlLineComment)] == sqlLineComment { + //nolint:revive // Write never returns error + stmtBuilder.Write(bytes.TrimRight(content[i:j], " ")) + for j < n && contentStr[j] != '\n' { + j += 1 + } + i = j + } + + default: + // no-op: generic character + } + } + + if len(st) > 0 { + switch st[len(st)-1] { + case sqlLeftParenthesis: + return nil, fmt.Errorf("error reading contents: unmatched left parenthesis") + case sqlBeginKeyword[0]: + return nil, fmt.Errorf("error reading contents: unmatched `BEGIN` keyword") + default: + // should never enter here + return nil, fmt.Errorf("error reading contents: unmatched `%c`", st[len(st)-1]) + } + } + + //nolint:revive // Write never returns error + stmtBuilder.Write(content[i:j]) + stmt := strings.TrimSpace(stmtBuilder.String()) if stmt == "" { continue } result = append(result, stmt) } - } return result, nil } + +// matchWord is a simple test to check if it matches the whole word: +// it checks if the adjacent charactes are not letters if they exist. +func matchWord(s, word string, pos int) bool { + if pos+len(word) > len(s) || s[pos:pos+len(word)] != word { + return false + } + if pos > 0 && isAlphanumeric(s[pos-1]) { + return false + } + if pos+len(word) < len(s) && isAlphanumeric(s[pos+len(word)]) { + return false + } + return true +} + +func isAlphanumeric(c byte) bool { + return unicode.IsLetter(rune(c)) || unicode.IsDigit(rune(c)) +} diff --git a/common/persistence/query_util_test.go b/common/persistence/query_util_test.go new file mode 100644 index 00000000000..afb1fa1dc50 --- /dev/null +++ b/common/persistence/query_util_test.go @@ -0,0 +1,131 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package persistence + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "go.temporal.io/server/common/log" +) + +type ( + queryUtilSuite struct { + suite.Suite + // override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test, + // not merely log an error + *require.Assertions + logger log.Logger + } +) + +func TestQueryUtilSuite(t *testing.T) { + s := new(queryUtilSuite) + suite.Run(t, s) +} + +func (s *queryUtilSuite) SetupTest() { + s.logger = log.NewTestLogger() + // Have to define our overridden assertions in the test setup. If we did it earlier, s.T() will return nil + s.Assertions = require.New(s.T()) +} + +func (s *queryUtilSuite) TestLoadAndSplitQueryFromReaders() { + input := ` + CREATE TABLE test ( + id BIGINT not null, + col1 BIGINT, -- comment with unmatched parenthesis ) + col2 VARCHAR(255), + PRIMARY KEY (id) + ); + + CREATE INDEX test_idx ON test (col1); + + CREATE TRIGGER test_ai AFTER INSERT ON test + BEGIN + SELECT *, 'string with unmatched chars ")' FROM test; + END; + ` + statements, err := LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)}) + s.NoError(err) + s.Equal(3, len(statements)) + s.Equal( + `CREATE TABLE test ( + id BIGINT not null, + col1 BIGINT, + col2 VARCHAR(255), + PRIMARY KEY (id) + );`, + statements[0], + ) + s.Equal(`CREATE INDEX test_idx ON test (col1);`, statements[1]) + s.Equal( + `CREATE TRIGGER test_ai AFTER INSERT ON test + BEGIN + SELECT *, 'string with unmatched chars ")' FROM test; + END;`, + statements[2], + ) + + input = "CREATE TABLE test (;" + statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)}) + s.Error(err, "error reading contents: unmatched left parenthesis") + s.Nil(statements) + + input = "CREATE TABLE test ());" + statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)}) + s.Error(err, "error reading contents: unmatched right parenthesis") + s.Nil(statements) + + input = "begin" + statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)}) + s.Error(err, "error reading contents: unmatched `BEGIN` keyword") + s.Nil(statements) + + input = "end" + statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)}) + s.Error(err, "error reading contents: unmatched `END` keyword") + s.Nil(statements) + + input = "select ' from test;" + statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)}) + s.Error(err, "error reading contents: unmatched quotes") + s.Nil(statements) +} + +func (s *queryUtilSuite) TestMatchWord() { + s.True(matchWord("BEGIN", "BEGIN", 0)) + s.True(matchWord(" BEGIN ", "BEGIN", 1)) + s.True(matchWord(")BEGIN;", "BEGIN", 1)) + s.False(matchWord("BEGIN", "BEGIN", 1)) + s.False(matchWord("sBEGIN", "BEGIN", 1)) + s.False(matchWord("BEGINs", "BEGIN", 0)) + s.False(matchWord("7BEGIN", "BEGIN", 1)) + s.False(matchWord("BEGIN7", "BEGIN", 0)) +} diff --git a/tools/common/schema/util.go b/tools/common/schema/util.go index c601dafeff5..00eb1f9ed66 100644 --- a/tools/common/schema/util.go +++ b/tools/common/schema/util.go @@ -25,10 +25,7 @@ package schema import ( - "bufio" - "io" - "os" - "strings" + "go.temporal.io/server/common/persistence" ) const newLineDelim = '\n' @@ -37,43 +34,5 @@ const newLineDelim = '\n' // and returns an array of cql / sql statements on // success. func ParseFile(filePath string) ([]string, error) { - // #nosec - f, err := os.Open(filePath) - if err != nil { - return nil, err - } - - reader := bufio.NewReader(f) - - var line string - var currStmt string - var stmts = make([]string, 0, 4) - - for err == nil { - - line, err = reader.ReadString(newLineDelim) - line = strings.TrimSpace(line) - if len(line) < 1 { - continue - } - - // Filter out the comment lines, the - // only recognized comment line format - // is any line that starts with double dashes - tokens := strings.Split(line, "--") - if len(tokens) > 0 && len(tokens[0]) > 0 { - currStmt += tokens[0] - // semi-colon is the end of statement delim - if strings.HasSuffix(currStmt, ";") { - stmts = append(stmts, currStmt) - currStmt = "" - } - } - } - - if err == io.EOF { - return stmts, nil - } - - return nil, err + return persistence.LoadAndSplitQuery([]string{filePath}) }