const (
	// queryDelimiter ends a statement when it is found outside of any
	// parenthesis, BEGIN/END block, quoted string or line comment.
	queryDelimiter = ';'
	// querySliceDefaultSize is the initial capacity of the result slice.
	querySliceDefaultSize = 100

	sqlLeftParenthesis  = '('
	sqlRightParenthesis = ')'
	// Keywords are spelled in lower case because they are matched against
	// an ASCII-lowered copy of the input.
	sqlBeginKeyword = "begin"
	sqlEndKeyword   = "end"
	sqlLineComment  = "--"
	sqlSingleQuote  = '\''
	sqlDoubleQuote  = '"'
)

// LoadAndSplitQueryFromReaders reads every reader to the end and splits its
// contents into individual statements, one per returned string.
//
// Statements are separated by ';'. A ';' inside parentheses, inside a
// BEGIN ... END block, inside a single- or double-quoted string, or inside a
// "--" line comment does not end a statement. Each statement is returned with
// its original casing, trimmed of surrounding whitespace and including its
// terminating ';' when present; blank statements are dropped. A statement
// never spans two readers.
//
// An error is returned if a reader fails, or if a statement has unbalanced
// parentheses, BEGIN/END keywords or quotes. On error the result is nil.
func LoadAndSplitQueryFromReaders(
	readers []io.Reader,
) ([]string, error) {
	result := make([]string, 0, querySliceDefaultSize)
	for _, r := range readers {
		content, err := io.ReadAll(r)
		if err != nil {
			return nil, fmt.Errorf("error reading contents: %w", err)
		}
		stmts, err := splitQueryStatements(content)
		if err != nil {
			return nil, err
		}
		result = append(result, stmts...)
	}
	return result, nil
}

// splitQueryStatements splits the contents of a single script into statements
// following the rules documented on LoadAndSplitQueryFromReaders.
func splitQueryStatements(content []byte) ([]string, error) {
	n := len(content)
	// The scan runs over a lowered copy so keywords match case-insensitively.
	// The copy must have exactly the same length as content because offsets
	// found in it are used to slice content.
	lowered := string(toLowerASCII(content))

	var result []string
	for i, j := 0, 0; i < n; i = j {
		// stack of currently open parentheses / BEGIN blocks
		var st []byte

	stmtLoop:
		for ; j < n; j++ {
			switch lowered[j] {
			case queryDelimiter:
				if len(st) == 0 {
					// keep the delimiter as part of the statement
					j++
					break stmtLoop
				}

			case sqlLeftParenthesis:
				st = append(st, sqlLeftParenthesis)

			case sqlRightParenthesis:
				if len(st) == 0 || st[len(st)-1] != sqlLeftParenthesis {
					return nil, fmt.Errorf("error reading contents: unmatched right parenthesis")
				}
				st = st[:len(st)-1]

			case sqlBeginKeyword[0]:
				if matchWord(lowered, sqlBeginKeyword, j) {
					st = append(st, sqlBeginKeyword[0])
					// land on the keyword's last byte; the loop's j++ steps past it
					j += len(sqlBeginKeyword) - 1
				}

			case sqlEndKeyword[0]:
				if matchWord(lowered, sqlEndKeyword, j) {
					if len(st) == 0 || st[len(st)-1] != sqlBeginKeyword[0] {
						return nil, fmt.Errorf("error reading contents: unmatched `END` keyword")
					}
					st = st[:len(st)-1]
					j += len(sqlEndKeyword) - 1
				}

			case sqlSingleQuote, sqlDoubleQuote:
				// skip to the matching closing quote
				quote := lowered[j]
				j++
				for j < n && lowered[j] != quote {
					j++
				}
				if j == n {
					return nil, fmt.Errorf("error reading contents: unmatched quotes")
				}

			case sqlLineComment[0]:
				if j+len(sqlLineComment) <= n && lowered[j:j+len(sqlLineComment)] == sqlLineComment {
					// Stop on the last byte of the comment rather than one
					// past it, so the loop's j++ can never push j beyond n
					// when the comment runs to the end of the input.
					for j+1 < n && lowered[j+1] != '\n' {
						j++
					}
				}

			default:
				// no-op: generic character
			}
		}

		if len(st) > 0 {
			switch st[len(st)-1] {
			case sqlLeftParenthesis:
				return nil, fmt.Errorf("error reading contents: unmatched left parenthesis")
			case sqlBeginKeyword[0]:
				return nil, fmt.Errorf("error reading contents: unmatched `BEGIN` keyword")
			default:
				// should never enter here
				return nil, fmt.Errorf("error reading contents: unmatched `%c`", st[len(st)-1])
			}
		}

		stmt := string(bytes.TrimSpace(content[i:j]))
		if stmt == "" {
			continue
		}
		result = append(result, stmt)
	}
	return result, nil
}

// toLowerASCII returns a copy of b with the ASCII letters A-Z lowered and
// every other byte left untouched, so the copy always has the same length
// as b.
func toLowerASCII(b []byte) []byte {
	out := make([]byte, len(b))
	for i, c := range b {
		if 'A' <= c && c <= 'Z' {
			c += 'a' - 'A'
		}
		out[i] = c
	}
	return out
}

// matchWord reports whether word occurs in s at byte offset pos as a whole
// word: the bytes immediately before and after the match, when they exist,
// must not be word characters (see isWordByte). The comparison is
// case-sensitive; callers are expected to normalize the case of s and word.
func matchWord(s, word string, pos int) bool {
	if pos+len(word) > len(s) || s[pos:pos+len(word)] != word {
		return false
	}
	if pos > 0 && isWordByte(s[pos-1]) {
		return false
	}
	if pos+len(word) < len(s) && isWordByte(s[pos+len(word)]) {
		return false
	}
	return true
}

// isWordByte reports whether c can be part of an identifier: a letter, a
// digit or an underscore. This keeps names such as `end_time` or `begin1`
// from being mistaken for keywords. Each byte is classified on its own, so
// bytes of multi-byte UTF-8 sequences are judged by their individual value.
func isWordByte(c byte) bool {
	return c == '_' || unicode.IsLetter(rune(c)) || unicode.IsDigit(rune(c))
}
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package persistence + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "go.temporal.io/server/common/log" +) + +type ( + queryUtilSuite struct { + suite.Suite + // override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test, + // not merely log an error + *require.Assertions + logger log.Logger + } +) + +func TestQueryUtilSuite(t *testing.T) { + s := new(queryUtilSuite) + suite.Run(t, s) +} + +func (s *queryUtilSuite) SetupTest() { + s.logger = log.NewTestLogger() + // Have to define our overridden assertions in the test setup. 
If we did it earlier, s.T() will return nil + s.Assertions = require.New(s.T()) +} + +func (s *queryUtilSuite) TestLoadAndSplitQueryFromReaders() { + input := ` + CREATE TABLE test ( + id BIGINT not null, + col1 BIGINT, -- comment with unmatched parenthesis ) + col2 VARCHAR(255), + PRIMARY KEY (id) + ); + + CREATE INDEX test_idx ON test (col1); + + CREATE TRIGGER test_ai AFTER INSERT ON test + BEGIN + SELECT *, 'string with unmatched chars ")' FROM test; + END; + ` + statements, err := LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)}) + s.NoError(err) + s.Equal(3, len(statements)) + s.Equal( + `CREATE TABLE test ( + id BIGINT not null, + col1 BIGINT, -- comment with unmatched parenthesis ) + col2 VARCHAR(255), + PRIMARY KEY (id) + );`, + statements[0], + ) + s.Equal(`CREATE INDEX test_idx ON test (col1);`, statements[1]) + s.Equal( + `CREATE TRIGGER test_ai AFTER INSERT ON test + BEGIN + SELECT *, 'string with unmatched chars ")' FROM test; + END;`, + statements[2], + ) + + input = "CREATE TABLE test (;" + statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)}) + s.Error(err, "error reading contents: unmatched left parenthesis") + s.Nil(statements) + + input = "CREATE TABLE test ());" + statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)}) + s.Error(err, "error reading contents: unmatched right parenthesis") + s.Nil(statements) + + input = "begin" + statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)}) + s.Error(err, "error reading contents: unmatched `BEGIN` keyword") + s.Nil(statements) + + input = "end" + statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)}) + s.Error(err, "error reading contents: unmatched `END` keyword") + s.Nil(statements) + + input = "select ' from test;" + statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)}) + s.Error(err, "error reading 
contents: unmatched quotes") + s.Nil(statements) +} + +func (s *queryUtilSuite) TestMatchWord() { + s.True(matchWord("BEGIN", "BEGIN", 0)) + s.True(matchWord(" BEGIN ", "BEGIN", 1)) + s.True(matchWord(")BEGIN;", "BEGIN", 1)) + s.False(matchWord("BEGIN", "BEGIN", 1)) + s.False(matchWord("sBEGIN", "BEGIN", 1)) + s.False(matchWord("BEGINs", "BEGIN", 1)) +}