diff --git a/common/persistence/query_util.go b/common/persistence/query_util.go index da7a58771b0..061001e251b 100644 --- a/common/persistence/query_util.go +++ b/common/persistence/query_util.go @@ -25,18 +25,29 @@ package persistence import ( + "bytes" "fmt" "io" "os" "strings" + "unicode" ) const ( - queryDelimiter = ";" + queryDelimiter = ';' querySliceDefaultSize = 100 + + sqlLeftParenthesis = '(' + sqlRightParenthesis = ')' + sqlBeginKeyword = "begin" + sqlEndKeyword = "end" + sqlLineComment = "--" + sqlSingleQuote = '\'' + sqlDoubleQuote = '"' ) -// LoadAndSplitQuery loads and split cql / sql query into one statement per string +// LoadAndSplitQuery loads and split cql / sql query into one statement per string. +// Comments are removed from the query. func LoadAndSplitQuery( filePaths []string, ) ([]string, error) { @@ -53,26 +64,119 @@ func LoadAndSplitQuery( return LoadAndSplitQueryFromReaders(files) } -// LoadAndSplitQueryFromReaders loads and split cql / sql query into one statement per string +// LoadAndSplitQueryFromReaders loads and split cql / sql query into one statement per string. +// Comments are removed from the query. func LoadAndSplitQueryFromReaders( readers []io.Reader, ) ([]string, error) { - result := make([]string, 0, querySliceDefaultSize) - for _, r := range readers { content, err := io.ReadAll(r) if err != nil { return nil, fmt.Errorf("error reading contents: %w", err) } - for _, stmt := range strings.Split(string(content), queryDelimiter) { - stmt = strings.TrimSpace(stmt) + n := len(content) + contentStr := string(bytes.ToLower(content)) + for i, j := 0, 0; i < n; i = j { + // stack to keep track of open parenthesis/blocks + var st []byte + var stmtBuilder strings.Builder + + stmtLoop: + for ; j < n; j++ { + switch contentStr[j] { + case queryDelimiter: + if len(st) == 0 { + j++ + break stmtLoop + } + + case sqlLeftParenthesis: + st = append(st, sqlLeftParenthesis) + + case sqlRightParenthesis: + if len(st) == 0 || st[len(st)-1] != sqlLeftParenthesis { + return nil, fmt.Errorf("error reading contents: unmatched right parenthesis") + } + st = st[:len(st)-1] + + case sqlBeginKeyword[0]: + if hasWordAt(contentStr, sqlBeginKeyword, j) { + st = append(st, sqlBeginKeyword[0]) + j += len(sqlBeginKeyword) - 1 + } + + case sqlEndKeyword[0]: + if hasWordAt(contentStr, sqlEndKeyword, j) { + if len(st) == 0 || st[len(st)-1] != sqlBeginKeyword[0] { + return nil, fmt.Errorf("error reading contents: unmatched `END` keyword") + } + st = st[:len(st)-1] + j += len(sqlEndKeyword) - 1 + } + + case sqlSingleQuote, sqlDoubleQuote: + quote := contentStr[j] + j++ + for j < n && contentStr[j] != quote { + j++ + } + if j == n { + return nil, fmt.Errorf("error reading contents: unmatched quotes") + } + + case sqlLineComment[0]: + if j+len(sqlLineComment) <= n && contentStr[j:j+len(sqlLineComment)] == sqlLineComment { + _, _ = stmtBuilder.Write(bytes.TrimRight(content[i:j], " ")) + for j < n && contentStr[j] != '\n' { + j++ + } + i = j + } + + default: + // no-op: generic character + } + } + + if len(st) > 0 { + switch st[len(st)-1] { + case sqlLeftParenthesis: + return nil, fmt.Errorf("error reading contents: unmatched left parenthesis") + case sqlBeginKeyword[0]: + return nil, fmt.Errorf("error reading contents: unmatched `BEGIN` keyword") + default: + // should never enter here + return nil, fmt.Errorf("error reading contents: unmatched `%c`", st[len(st)-1]) + } + } + + _, _ = stmtBuilder.Write(content[i:j]) + stmt := strings.TrimSpace(stmtBuilder.String()) if stmt == "" { continue } result = append(result, stmt) } - } return result, nil } + +// hasWordAt is a simple test to check if it matches the whole word: +// it checks if the adjacent charactes are not alphanumeric if they exist. +func hasWordAt(s, word string, pos int) bool { + if pos+len(word) > len(s) || s[pos:pos+len(word)] != word { + return false + } + if pos > 0 && isAlphanumeric(s[pos-1]) { + return false + } + if pos+len(word) < len(s) && isAlphanumeric(s[pos+len(word)]) { + return false + } + return true +} + +func isAlphanumeric(c byte) bool { + return unicode.IsLetter(rune(c)) || unicode.IsDigit(rune(c)) +} diff --git a/common/persistence/query_util_test.go b/common/persistence/query_util_test.go new file mode 100644 index 00000000000..ae5ff96181a --- /dev/null +++ b/common/persistence/query_util_test.go @@ -0,0 +1,137 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package persistence + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "go.temporal.io/server/common/log" +) + +type ( + queryUtilSuite struct { + suite.Suite + // override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test, + // not merely log an error + *require.Assertions + logger log.Logger + } +) + +func TestQueryUtilSuite(t *testing.T) { + s := new(queryUtilSuite) + suite.Run(t, s) +} + +func (s *queryUtilSuite) SetupTest() { + s.logger = log.NewTestLogger() + // Have to define our overridden assertions in the test setup. If we did it earlier, s.T() will return nil + s.Assertions = require.New(s.T()) +} + +func (s *queryUtilSuite) TestLoadAndSplitQueryFromReaders() { + input := ` + CREATE TABLE test ( + id BIGINT not null, + col1 BIGINT, -- comment with unmatched parenthesis ) + col2 VARCHAR(255), + PRIMARY KEY (id) + ); + + CREATE INDEX test_idx ON test (col1); + + --begin + CREATE TRIGGER test_ai AFTER INSERT ON test + BEGIN + SELECT *, 'string with unmatched chars ")' FROM test; + --end + END; + + -- trailing comment + ` + statements, err := LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)}) + s.NoError(err) + s.Equal(3, len(statements)) + s.Equal( + `CREATE TABLE test ( + id BIGINT not null, + col1 BIGINT, + col2 VARCHAR(255), + PRIMARY KEY (id) + );`, + statements[0], + ) + s.Equal(`CREATE INDEX test_idx ON test (col1);`, statements[1]) + // comments are removed, but the inner content is not trimmed + s.Equal( + `CREATE TRIGGER test_ai AFTER INSERT ON test + BEGIN + SELECT *, 'string with unmatched chars ")' FROM test; + + END;`, + statements[2], + ) + + input = "CREATE TABLE test (;" + statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)}) + s.Error(err, "error reading contents: unmatched left parenthesis") + s.Nil(statements) + + input = "CREATE TABLE test ());" + statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)}) + s.Error(err, "error reading contents: unmatched right parenthesis") + s.Nil(statements) + + input = "begin" + statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)}) + s.Error(err, "error reading contents: unmatched `BEGIN` keyword") + s.Nil(statements) + + input = "end" + statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)}) + s.Error(err, "error reading contents: unmatched `END` keyword") + s.Nil(statements) + + input = "select ' from test;" + statements, err = LoadAndSplitQueryFromReaders([]io.Reader{bytes.NewBufferString(input)}) + s.Error(err, "error reading contents: unmatched quotes") + s.Nil(statements) +} + +func (s *queryUtilSuite) TestHasWordAt() { + s.True(hasWordAt("BEGIN", "BEGIN", 0)) + s.True(hasWordAt(" BEGIN ", "BEGIN", 1)) + s.True(hasWordAt(")BEGIN;", "BEGIN", 1)) + s.False(hasWordAt("BEGIN", "BEGIN", 1)) + s.False(hasWordAt("sBEGIN", "BEGIN", 1)) + s.False(hasWordAt("BEGINs", "BEGIN", 0)) + s.False(hasWordAt("7BEGIN", "BEGIN", 1)) + s.False(hasWordAt("BEGIN7", "BEGIN", 0)) +} diff --git a/tools/common/schema/setuptask.go b/tools/common/schema/setuptask.go index 04227b4ee71..63e8a259ab1 100644 --- a/tools/common/schema/setuptask.go +++ b/tools/common/schema/setuptask.go @@ -32,6 +32,7 @@ import ( "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" + "go.temporal.io/server/common/persistence" ) // SetupTask represents a task @@ -75,7 +76,7 @@ func (task *SetupTask) Run() error { if err != nil { return err } - stmts, err := ParseFile(filePath) + stmts, err := persistence.LoadAndSplitQuery([]string{filePath}) if err != nil { return err } diff --git a/tools/common/schema/test/dbtest.go b/tools/common/schema/test/dbtest.go index 5932ae9e33a..7b3d2866081 100644 --- a/tools/common/schema/test/dbtest.go +++ b/tools/common/schema/test/dbtest.go @@ -34,6 +34,7 @@ import ( "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" + "go.temporal.io/server/common/persistence" "go.temporal.io/server/tests/testutils" "go.temporal.io/server/tools/common/schema" ) @@ -83,7 +84,7 @@ func (tb *DBTestBase) RunParseFileTest(content string) { _, err := cqlFile.WriteString(content) tb.NoError(err) - stmts, err := schema.ParseFile(cqlFile.Name()) + stmts, err := persistence.LoadAndSplitQuery([]string{cqlFile.Name()}) tb.Nil(err) tb.Equal(2, len(stmts), "wrong number of sql statements") } diff --git a/tools/common/schema/updatetask.go b/tools/common/schema/updatetask.go index 0bcb3acb4a7..391b51c3578 100644 --- a/tools/common/schema/updatetask.go +++ b/tools/common/schema/updatetask.go @@ -41,6 +41,7 @@ import ( "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" + "go.temporal.io/server/common/persistence" ) type ( @@ -230,7 +231,7 @@ func (task *UpdateTask) parseSQLStmts(dir string, manifest *manifest) ([]string, for _, file := range manifest.SchemaUpdateCqlFiles { path := dir + "/" + file task.logger.Info("Processing schema file: " + path) - stmts, err := ParseFile(path) + stmts, err := persistence.LoadAndSplitQuery([]string{path}) if err != nil { return nil, fmt.Errorf("error parsing file %v, err=%v", path, err) } diff --git a/tools/common/schema/util.go b/tools/common/schema/util.go deleted file mode 100644 index c601dafeff5..00000000000 --- a/tools/common/schema/util.go +++ /dev/null @@ -1,79 +0,0 @@ -// The MIT License -// -// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. -// -// Copyright (c) 2020 Uber Technologies, Inc. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -package schema - -import ( - "bufio" - "io" - "os" - "strings" -) - -const newLineDelim = '\n' - -// ParseFile takes a cql / sql file path as input -// and returns an array of cql / sql statements on -// success. -func ParseFile(filePath string) ([]string, error) { - // #nosec - f, err := os.Open(filePath) - if err != nil { - return nil, err - } - - reader := bufio.NewReader(f) - - var line string - var currStmt string - var stmts = make([]string, 0, 4) - - for err == nil { - - line, err = reader.ReadString(newLineDelim) - line = strings.TrimSpace(line) - if len(line) < 1 { - continue - } - - // Filter out the comment lines, the - // only recognized comment line format - // is any line that starts with double dashes - tokens := strings.Split(line, "--") - if len(tokens) > 0 && len(tokens[0]) > 0 { - currStmt += tokens[0] - // semi-colon is the end of statement delim - if strings.HasSuffix(currStmt, ";") { - stmts = append(stmts, currStmt) - currStmt = "" - } - } - } - - if err == io.EOF { - return stmts, nil - } - - return nil, err -}