Skip to content

Commit

Permalink
Fix wrap user query string in parenthesis (#3967)
Browse files Browse the repository at this point in the history
  • Loading branch information
rodrigozhou authored Feb 16, 2023
1 parent 7ab7e9c commit cbed0fe
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 6 deletions.
11 changes: 11 additions & 0 deletions common/persistence/visibility/store/sql/query_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,17 @@ func (c *QueryConverter) convertSelectStmt(sel *sqlparser.Select) error {
if err != nil {
return err
}

// Wrap user's query in parenthesis. This is to ensure that further changes
// to the query won't affect the user's query.
switch sel.Where.Expr.(type) {
case *sqlparser.ParenExpr:
// no-op: top-level expression is already a parenthesis
default:
sel.Where.Expr = &sqlparser.ParenExpr{
Expr: sel.Where.Expr,
}
}
}

// This logic comes from elasticsearch/visibility_store.go#convertQuery function.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ func (c *mysqlQueryConverter) buildSelectStmt(
queryArgs = append(queryArgs, namespaceID.String())

if len(queryString) > 0 {
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", queryString))
whereClauses = append(whereClauses, queryString)
}

if token != nil {
Expand Down Expand Up @@ -283,7 +283,7 @@ func (c *mysqlQueryConverter) buildCountStmt(
queryArgs = append(queryArgs, namespaceID.String())

if len(queryString) > 0 {
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", queryString))
whereClauses = append(whereClauses, queryString)
}

return fmt.Sprintf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func (c *pgQueryConverter) buildSelectStmt(
queryArgs = append(queryArgs, namespaceID.String())

if len(queryString) > 0 {
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", queryString))
whereClauses = append(whereClauses, queryString)
}

if token != nil {
Expand Down Expand Up @@ -286,7 +286,7 @@ func (c *pgQueryConverter) buildCountStmt(
queryArgs = append(queryArgs, namespaceID.String())

if len(queryString) > 0 {
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", queryString))
whereClauses = append(whereClauses, queryString)
}

return fmt.Sprintf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func (c *sqliteQueryConverter) buildSelectStmt(
queryArgs = append(queryArgs, namespaceID.String())

if len(queryString) > 0 {
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", queryString))
whereClauses = append(whereClauses, queryString)
}

if token != nil {
Expand Down Expand Up @@ -328,7 +328,7 @@ func (c *sqliteQueryConverter) buildCountStmt(
queryArgs = append(queryArgs, namespaceID.String())

if len(queryString) > 0 {
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", queryString))
whereClauses = append(whereClauses, queryString)
}

return fmt.Sprintf(
Expand Down

0 comments on commit cbed0fe

Please sign in to comment.