Skip to content

Commit

Permalink
Support CountWorkflowExecutions for SQL DB (#3955)
Browse files Browse the repository at this point in the history
  • Loading branch information
rodrigozhou authored Feb 15, 2023
1 parent f4e4ea2 commit 099cef3
Show file tree
Hide file tree
Showing 13 changed files with 229 additions and 35 deletions.
8 changes: 8 additions & 0 deletions common/persistence/sql/sqlplugin/mysql/visibility.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"fmt"

"go.temporal.io/server/common/persistence/sql/sqlplugin"
"go.temporal.io/server/common/persistence/visibility/store"
)

const (
Expand Down Expand Up @@ -297,6 +298,13 @@ func (mdb *db) GetFromVisibility(
return &row, nil
}

func (mdb *db) CountFromVisibility(
ctx context.Context,
filter sqlplugin.VisibilitySelectFilter,
) (int64, error) {
return 0, store.OperationNotSupportedErr
}

func (mdb *db) processRowFromDB(row *sqlplugin.VisibilityRow) {
row.StartTime = mdb.converter.FromMySQLDateTime(row.StartTime)
row.ExecutionTime = mdb.converter.FromMySQLDateTime(row.ExecutionTime)
Expand Down
12 changes: 12 additions & 0 deletions common/persistence/sql/sqlplugin/mysql/visibility_v8.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,18 @@ func (mdb *dbV8) GetFromVisibility(
return &row, nil
}

func (mdb *dbV8) CountFromVisibility(
ctx context.Context,
filter sqlplugin.VisibilitySelectFilter,
) (int64, error) {
var count int64
err := mdb.conn.GetContext(ctx, &count, filter.Query, filter.QueryArgs...)
if err != nil {
return 0, err
}
return count, nil
}

func (mdb *dbV8) prepareRowForDB(row *sqlplugin.VisibilityRow) *sqlplugin.VisibilityRow {
if row == nil {
return nil
Expand Down
8 changes: 8 additions & 0 deletions common/persistence/sql/sqlplugin/postgresql/visibility.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"strings"

"go.temporal.io/server/common/persistence/sql/sqlplugin"
"go.temporal.io/server/common/persistence/visibility/store"
)

const (
Expand Down Expand Up @@ -316,6 +317,13 @@ func (pdb *db) GetFromVisibility(
return &row, nil
}

func (pdb *db) CountFromVisibility(
ctx context.Context,
filter sqlplugin.VisibilitySelectFilter,
) (int64, error) {
return 0, store.OperationNotSupportedErr
}

func (pdb *db) processRowFromDB(row *sqlplugin.VisibilityRow) {
row.StartTime = pdb.converter.FromPostgreSQLDateTime(row.StartTime)
row.ExecutionTime = pdb.converter.FromPostgreSQLDateTime(row.ExecutionTime)
Expand Down
13 changes: 13 additions & 0 deletions common/persistence/sql/sqlplugin/postgresql/visibility_v12.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,19 @@ func (pdb *dbV12) GetFromVisibility(
return &row, nil
}

func (pdb *dbV12) CountFromVisibility(
ctx context.Context,
filter sqlplugin.VisibilitySelectFilter,
) (int64, error) {
var count int64
filter.Query = pdb.db.db.Rebind(filter.Query)
err := pdb.conn.GetContext(ctx, &count, filter.Query, filter.QueryArgs...)
if err != nil {
return 0, err
}
return count, nil
}

func (pdb *dbV12) prepareRowForDB(row *sqlplugin.VisibilityRow) *sqlplugin.VisibilityRow {
if row == nil {
return nil
Expand Down
12 changes: 12 additions & 0 deletions common/persistence/sql/sqlplugin/sqlite/visibility.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,18 @@ func (mdb *db) GetFromVisibility(
return &row, nil
}

func (mdb *db) CountFromVisibility(
ctx context.Context,
filter sqlplugin.VisibilitySelectFilter,
) (int64, error) {
var count int64
err := mdb.conn.GetContext(ctx, &count, filter.Query, filter.QueryArgs...)
if err != nil {
return 0, err
}
return count, nil
}

func (mdb *db) prepareRowForDB(row *sqlplugin.VisibilityRow) *sqlplugin.VisibilityRow {
if row == nil {
return nil
Expand Down
1 change: 1 addition & 0 deletions common/persistence/sql/sqlplugin/visibility.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ type (
SelectFromVisibility(ctx context.Context, filter VisibilitySelectFilter) ([]VisibilityRow, error)
GetFromVisibility(ctx context.Context, filter VisibilityGetFilter) (*VisibilityRow, error)
DeleteFromVisibility(ctx context.Context, filter VisibilityDeleteFilter) (sql.Result, error)
CountFromVisibility(ctx context.Context, filter VisibilitySelectFilter) (int64, error)
}
)

Expand Down
46 changes: 34 additions & 12 deletions common/persistence/visibility/store/sql/query_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ import (
enumspb "go.temporal.io/api/enums/v1"
"go.temporal.io/server/common/namespace"
"go.temporal.io/server/common/persistence/sql/sqlplugin"
"go.temporal.io/server/common/persistence/visibility/manager"
"go.temporal.io/server/common/persistence/visibility/store/query"
"go.temporal.io/server/common/searchattribute"
)
Expand All @@ -54,16 +53,20 @@ type (
token *pageToken,
) (string, []any)

buildCountStmt(namespaceID namespace.ID, queryString string) (string, []any)

getDatetimeFormat() string

getCoalesceCloseTimeExpr() sqlparser.Expr
}

QueryConverter struct {
pluginQueryConverter
request *manager.ListWorkflowExecutionsRequestV2
saTypeMap searchattribute.NameTypeMap
saMapper searchattribute.Mapper
namespaceName namespace.Name
namespaceID namespace.ID
saTypeMap searchattribute.NameTypeMap
saMapper searchattribute.Mapper
queryString string

seenNamespaceDivision bool
}
Expand Down Expand Up @@ -114,38 +117,57 @@ var (

func newQueryConverterInternal(
pqc pluginQueryConverter,
request *manager.ListWorkflowExecutionsRequestV2,
namespaceName namespace.Name,
namespaceID namespace.ID,
saTypeMap searchattribute.NameTypeMap,
saMapper searchattribute.Mapper,
queryString string,
) *QueryConverter {
return &QueryConverter{
pluginQueryConverter: pqc,
request: request,
namespaceName: namespaceName,
namespaceID: namespaceID,
saTypeMap: saTypeMap,
saMapper: saMapper,
queryString: queryString,

seenNamespaceDivision: false,
}
}

func (c *QueryConverter) BuildSelectStmt() (*sqlplugin.VisibilitySelectFilter, error) {
token, err := deserializePageToken(c.request.NextPageToken)
func (c *QueryConverter) BuildSelectStmt(
pageSize int,
nextPageToken []byte,
) (*sqlplugin.VisibilitySelectFilter, error) {
token, err := deserializePageToken(nextPageToken)
if err != nil {
return nil, err
}
queryString, err := c.convertWhereString(c.request.Query)
queryString, err := c.convertWhereString(c.queryString)
if err != nil {
return nil, err
}
queryString, queryArgs := c.buildSelectStmt(
c.request.NamespaceID,
c.namespaceID,
queryString,
c.request.PageSize,
pageSize,
token,
)
return &sqlplugin.VisibilitySelectFilter{Query: queryString, QueryArgs: queryArgs}, nil
}

func (c *QueryConverter) BuildCountStmt() (*sqlplugin.VisibilitySelectFilter, error) {
queryString, err := c.convertWhereString(c.queryString)
if err != nil {
return nil, err
}
queryString, queryArgs := c.buildCountStmt(
c.namespaceID,
queryString,
)
return &sqlplugin.VisibilitySelectFilter{Query: queryString, QueryArgs: queryArgs}, nil
}

func (c *QueryConverter) convertWhereString(queryString string) (string, error) {
where := strings.TrimSpace(queryString)
if where != "" && !strings.HasPrefix(strings.ToLower(where), "order by") {
Expand Down Expand Up @@ -376,7 +398,7 @@ func (c *QueryConverter) convertColName(
saFieldName = saAlias
if searchattribute.IsMappable(saAlias) {
var err error
saFieldName, err = c.saMapper.GetFieldName(saAlias, c.request.Namespace.String())
saFieldName, err = c.saMapper.GetFieldName(saAlias, c.namespaceName.String())
if err != nil {
return "", "", err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,28 @@
package sql

import (
"go.temporal.io/server/common/namespace"
"go.temporal.io/server/common/persistence/sql/sqlplugin/mysql"
"go.temporal.io/server/common/persistence/sql/sqlplugin/postgresql"
"go.temporal.io/server/common/persistence/sql/sqlplugin/sqlite"
"go.temporal.io/server/common/persistence/visibility/manager"
"go.temporal.io/server/common/searchattribute"
)

func NewQueryConverter(
pluginName string,
request *manager.ListWorkflowExecutionsRequestV2,
namespaceName namespace.Name,
namespaceID namespace.ID,
saTypeMap searchattribute.NameTypeMap,
saMapper searchattribute.Mapper,
queryString string,
) *QueryConverter {
switch pluginName {
case mysql.PluginNameV8:
return newMySQLQueryConverter(request, saTypeMap, saMapper)
return newMySQLQueryConverter(namespaceName, namespaceID, saTypeMap, saMapper, queryString)
case postgresql.PluginNameV12:
return newPostgreSQLQueryConverter(request, saTypeMap, saMapper)
return newPostgreSQLQueryConverter(namespaceName, namespaceID, saTypeMap, saMapper, queryString)
case sqlite.PluginName:
return newSqliteQueryConverter(request, saTypeMap, saMapper)
return newSqliteQueryConverter(namespaceName, namespaceID, saTypeMap, saMapper, queryString)
default:
return nil
}
Expand Down
38 changes: 35 additions & 3 deletions common/persistence/visibility/store/sql/query_converter_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (
"github.com/xwb1989/sqlparser"
"go.temporal.io/server/common/namespace"
"go.temporal.io/server/common/persistence/sql/sqlplugin"
"go.temporal.io/server/common/persistence/visibility/manager"
"go.temporal.io/server/common/persistence/visibility/store/query"
"go.temporal.io/server/common/searchattribute"
)
Expand Down Expand Up @@ -83,15 +82,19 @@ func (node *jsonOverlapsExpr) Format(buf *sqlparser.TrackedBuffer) {
}

func newMySQLQueryConverter(
request *manager.ListWorkflowExecutionsRequestV2,
namespaceName namespace.Name,
namespaceID namespace.ID,
saTypeMap searchattribute.NameTypeMap,
saMapper searchattribute.Mapper,
queryString string,
) *QueryConverter {
return newQueryConverterInternal(
&mysqlQueryConverter{},
request,
namespaceName,
namespaceID,
saTypeMap,
saMapper,
queryString,
)
}

Expand Down Expand Up @@ -253,3 +256,32 @@ func (c *mysqlQueryConverter) buildSelectStmt(
searchattribute.GetSqlDbColName(searchattribute.RunID),
), queryArgs
}

func (c *mysqlQueryConverter) buildCountStmt(
namespaceID namespace.ID,
queryString string,
) (string, []any) {
var whereClauses []string
var queryArgs []any

whereClauses = append(
whereClauses,
fmt.Sprintf("(%s = ?)", searchattribute.GetSqlDbColName(searchattribute.NamespaceID)),
)
queryArgs = append(queryArgs, namespaceID.String())

if len(queryString) > 0 {
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", queryString))
}

return fmt.Sprintf(
`SELECT COUNT(1)
FROM executions_visibility ev
LEFT JOIN custom_search_attributes
USING (%s, %s)
WHERE %s`,
searchattribute.GetSqlDbColName(searchattribute.NamespaceID),
searchattribute.GetSqlDbColName(searchattribute.RunID),
strings.Join(whereClauses, " AND "),
), queryArgs
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"github.com/xwb1989/sqlparser"
"go.temporal.io/server/common/namespace"
"go.temporal.io/server/common/persistence/sql/sqlplugin"
"go.temporal.io/server/common/persistence/visibility/manager"
"go.temporal.io/server/common/persistence/visibility/store/query"
"go.temporal.io/server/common/searchattribute"
)
Expand Down Expand Up @@ -64,15 +63,19 @@ func (node *pgCastExpr) Format(buf *sqlparser.TrackedBuffer) {
}

func newPostgreSQLQueryConverter(
request *manager.ListWorkflowExecutionsRequestV2,
namespaceName namespace.Name,
namespaceID namespace.ID,
saTypeMap searchattribute.NameTypeMap,
saMapper searchattribute.Mapper,
queryString string,
) *QueryConverter {
return newQueryConverterInternal(
&pgQueryConverter{},
request,
namespaceName,
namespaceID,
saTypeMap,
saMapper,
queryString,
)
}

Expand Down Expand Up @@ -256,3 +259,26 @@ func (c *pgQueryConverter) buildSelectStmt(
searchattribute.GetSqlDbColName(searchattribute.RunID),
), queryArgs
}

func (c *pgQueryConverter) buildCountStmt(
namespaceID namespace.ID,
queryString string,
) (string, []any) {
var whereClauses []string
var queryArgs []any

whereClauses = append(
whereClauses,
fmt.Sprintf("(%s = ?)", searchattribute.GetSqlDbColName(searchattribute.NamespaceID)),
)
queryArgs = append(queryArgs, namespaceID.String())

if len(queryString) > 0 {
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", queryString))
}

return fmt.Sprintf(
"SELECT COUNT(1) FROM executions_visibility WHERE %s",
strings.Join(whereClauses, " AND "),
), queryArgs
}
Loading

0 comments on commit 099cef3

Please sign in to comment.