From 099cef3e2aff60318bfb71be6d83f9d928b2c411 Mon Sep 17 00:00:00 2001 From: Rodrigo Zhou Date: Tue, 14 Feb 2023 16:59:12 -0800 Subject: [PATCH] Support CountWorkflowExecutions for SQL DB (#3955) --- .../sql/sqlplugin/mysql/visibility.go | 8 ++++ .../sql/sqlplugin/mysql/visibility_v8.go | 12 +++++ .../sql/sqlplugin/postgresql/visibility.go | 8 ++++ .../sqlplugin/postgresql/visibility_v12.go | 13 ++++++ .../sql/sqlplugin/sqlite/visibility.go | 12 +++++ .../persistence/sql/sqlplugin/visibility.go | 1 + .../visibility/store/sql/query_converter.go | 46 ++++++++++++++----- .../store/sql/query_converter_factory.go | 12 +++-- .../store/sql/query_converter_mysql.go | 38 +++++++++++++-- .../store/sql/query_converter_postgresql.go | 32 +++++++++++-- .../store/sql/query_converter_sqlite.go | 32 +++++++++++-- .../visibility/store/sql/visibility_store.go | 46 +++++++++++++++++-- tests/advanced_visibility_test.go | 4 -- 13 files changed, 229 insertions(+), 35 deletions(-) diff --git a/common/persistence/sql/sqlplugin/mysql/visibility.go b/common/persistence/sql/sqlplugin/mysql/visibility.go index bfcbaa26770..edb7cce2519 100644 --- a/common/persistence/sql/sqlplugin/mysql/visibility.go +++ b/common/persistence/sql/sqlplugin/mysql/visibility.go @@ -31,6 +31,7 @@ import ( "fmt" "go.temporal.io/server/common/persistence/sql/sqlplugin" + "go.temporal.io/server/common/persistence/visibility/store" ) const ( @@ -297,6 +298,13 @@ func (mdb *db) GetFromVisibility( return &row, nil } +func (mdb *db) CountFromVisibility( + ctx context.Context, + filter sqlplugin.VisibilitySelectFilter, +) (int64, error) { + return 0, store.OperationNotSupportedErr +} + func (mdb *db) processRowFromDB(row *sqlplugin.VisibilityRow) { row.StartTime = mdb.converter.FromMySQLDateTime(row.StartTime) row.ExecutionTime = mdb.converter.FromMySQLDateTime(row.ExecutionTime) diff --git a/common/persistence/sql/sqlplugin/mysql/visibility_v8.go b/common/persistence/sql/sqlplugin/mysql/visibility_v8.go index 15987299f59..b9c10b0b279 100644 --- a/common/persistence/sql/sqlplugin/mysql/visibility_v8.go +++ b/common/persistence/sql/sqlplugin/mysql/visibility_v8.go @@ -234,6 +234,18 @@ func (mdb *dbV8) GetFromVisibility( return &row, nil } +func (mdb *dbV8) CountFromVisibility( + ctx context.Context, + filter sqlplugin.VisibilitySelectFilter, +) (int64, error) { + var count int64 + err := mdb.conn.GetContext(ctx, &count, filter.Query, filter.QueryArgs...) + if err != nil { + return 0, err + } + return count, nil +} + func (mdb *dbV8) prepareRowForDB(row *sqlplugin.VisibilityRow) *sqlplugin.VisibilityRow { if row == nil { return nil diff --git a/common/persistence/sql/sqlplugin/postgresql/visibility.go b/common/persistence/sql/sqlplugin/postgresql/visibility.go index 3db432497ae..e58d6fa0b4e 100644 --- a/common/persistence/sql/sqlplugin/postgresql/visibility.go +++ b/common/persistence/sql/sqlplugin/postgresql/visibility.go @@ -32,6 +32,7 @@ import ( "strings" "go.temporal.io/server/common/persistence/sql/sqlplugin" + "go.temporal.io/server/common/persistence/visibility/store" ) const ( @@ -316,6 +317,13 @@ func (pdb *db) GetFromVisibility( return &row, nil } +func (pdb *db) CountFromVisibility( + ctx context.Context, + filter sqlplugin.VisibilitySelectFilter, +) (int64, error) { + return 0, store.OperationNotSupportedErr +} + func (pdb *db) processRowFromDB(row *sqlplugin.VisibilityRow) { row.StartTime = pdb.converter.FromPostgreSQLDateTime(row.StartTime) row.ExecutionTime = pdb.converter.FromPostgreSQLDateTime(row.ExecutionTime) diff --git a/common/persistence/sql/sqlplugin/postgresql/visibility_v12.go b/common/persistence/sql/sqlplugin/postgresql/visibility_v12.go index 6c9543e3e01..a9879085951 100644 --- a/common/persistence/sql/sqlplugin/postgresql/visibility_v12.go +++ b/common/persistence/sql/sqlplugin/postgresql/visibility_v12.go @@ -150,6 +150,19 @@ func (pdb *dbV12) GetFromVisibility( return &row, nil } +func (pdb *dbV12) CountFromVisibility( + ctx context.Context, + filter sqlplugin.VisibilitySelectFilter, +) (int64, error) { + var count int64 + filter.Query = pdb.db.db.Rebind(filter.Query) + err := pdb.conn.GetContext(ctx, &count, filter.Query, filter.QueryArgs...) + if err != nil { + return 0, err + } + return count, nil +} + func (pdb *dbV12) prepareRowForDB(row *sqlplugin.VisibilityRow) *sqlplugin.VisibilityRow { if row == nil { return nil diff --git a/common/persistence/sql/sqlplugin/sqlite/visibility.go b/common/persistence/sql/sqlplugin/sqlite/visibility.go index c2134b8d274..0604322ecc3 100644 --- a/common/persistence/sql/sqlplugin/sqlite/visibility.go +++ b/common/persistence/sql/sqlplugin/sqlite/visibility.go @@ -152,6 +152,18 @@ func (mdb *db) GetFromVisibility( return &row, nil } +func (mdb *db) CountFromVisibility( + ctx context.Context, + filter sqlplugin.VisibilitySelectFilter, +) (int64, error) { + var count int64 + err := mdb.conn.GetContext(ctx, &count, filter.Query, filter.QueryArgs...) + if err != nil { + return 0, err + } + return count, nil +} + func (mdb *db) prepareRowForDB(row *sqlplugin.VisibilityRow) *sqlplugin.VisibilityRow { if row == nil { return nil diff --git a/common/persistence/sql/sqlplugin/visibility.go b/common/persistence/sql/sqlplugin/visibility.go index f46973d1b4d..0bf8076a213 100644 --- a/common/persistence/sql/sqlplugin/visibility.go +++ b/common/persistence/sql/sqlplugin/visibility.go @@ -105,6 +105,7 @@ type ( SelectFromVisibility(ctx context.Context, filter VisibilitySelectFilter) ([]VisibilityRow, error) GetFromVisibility(ctx context.Context, filter VisibilityGetFilter) (*VisibilityRow, error) DeleteFromVisibility(ctx context.Context, filter VisibilityDeleteFilter) (sql.Result, error) + CountFromVisibility(ctx context.Context, filter VisibilitySelectFilter) (int64, error) } ) diff --git a/common/persistence/visibility/store/sql/query_converter.go b/common/persistence/visibility/store/sql/query_converter.go index 547c877deea..3135f0a3bd1 100644 --- a/common/persistence/visibility/store/sql/query_converter.go +++ b/common/persistence/visibility/store/sql/query_converter.go @@ -36,7 +36,6 @@ import ( enumspb "go.temporal.io/api/enums/v1" "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence/sql/sqlplugin" - "go.temporal.io/server/common/persistence/visibility/manager" "go.temporal.io/server/common/persistence/visibility/store/query" "go.temporal.io/server/common/searchattribute" ) @@ -54,6 +53,8 @@ type ( token *pageToken, ) (string, []any) + buildCountStmt(namespaceID namespace.ID, queryString string) (string, []any) + getDatetimeFormat() string getCoalesceCloseTimeExpr() sqlparser.Expr @@ -61,9 +62,11 @@ type ( QueryConverter struct { pluginQueryConverter - request *manager.ListWorkflowExecutionsRequestV2 - saTypeMap searchattribute.NameTypeMap - saMapper searchattribute.Mapper + namespaceName namespace.Name + namespaceID namespace.ID + saTypeMap searchattribute.NameTypeMap + saMapper searchattribute.Mapper + queryString string seenNamespaceDivision bool } @@ -114,38 +117,57 @@ var ( func newQueryConverterInternal( pqc pluginQueryConverter, - request *manager.ListWorkflowExecutionsRequestV2, + namespaceName namespace.Name, + namespaceID namespace.ID, saTypeMap searchattribute.NameTypeMap, saMapper searchattribute.Mapper, + queryString string, ) *QueryConverter { return &QueryConverter{ pluginQueryConverter: pqc, - request: request, + namespaceName: namespaceName, + namespaceID: namespaceID, saTypeMap: saTypeMap, saMapper: saMapper, + queryString: queryString, seenNamespaceDivision: false, } } -func (c *QueryConverter) BuildSelectStmt() (*sqlplugin.VisibilitySelectFilter, error) { - token, err := deserializePageToken(c.request.NextPageToken) +func (c *QueryConverter) BuildSelectStmt( + pageSize int, + nextPageToken []byte, +) (*sqlplugin.VisibilitySelectFilter, error) { + token, err := deserializePageToken(nextPageToken) if err != nil { return nil, err } - queryString, err := c.convertWhereString(c.request.Query) + queryString, err := c.convertWhereString(c.queryString) if err != nil { return nil, err } queryString, queryArgs := c.buildSelectStmt( - c.request.NamespaceID, + c.namespaceID, queryString, - c.request.PageSize, + pageSize, token, ) return &sqlplugin.VisibilitySelectFilter{Query: queryString, QueryArgs: queryArgs}, nil } +func (c *QueryConverter) BuildCountStmt() (*sqlplugin.VisibilitySelectFilter, error) { + queryString, err := c.convertWhereString(c.queryString) + if err != nil { + return nil, err + } + queryString, queryArgs := c.buildCountStmt( + c.namespaceID, + queryString, + ) + return &sqlplugin.VisibilitySelectFilter{Query: queryString, QueryArgs: queryArgs}, nil +} + func (c *QueryConverter) convertWhereString(queryString string) (string, error) { where := strings.TrimSpace(queryString) if where != "" && !strings.HasPrefix(strings.ToLower(where), "order by") { @@ -376,7 +398,7 @@ func (c *QueryConverter) convertColName( saFieldName = saAlias if searchattribute.IsMappable(saAlias) { var err error - saFieldName, err = c.saMapper.GetFieldName(saAlias, c.request.Namespace.String()) + saFieldName, err = c.saMapper.GetFieldName(saAlias, c.namespaceName.String()) if err != nil { return "", "", err } diff --git a/common/persistence/visibility/store/sql/query_converter_factory.go b/common/persistence/visibility/store/sql/query_converter_factory.go index cb866001eef..299f41f130e 100644 --- a/common/persistence/visibility/store/sql/query_converter_factory.go +++ b/common/persistence/visibility/store/sql/query_converter_factory.go @@ -25,26 +25,28 @@ package sql import ( + "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence/sql/sqlplugin/mysql" "go.temporal.io/server/common/persistence/sql/sqlplugin/postgresql" "go.temporal.io/server/common/persistence/sql/sqlplugin/sqlite" - "go.temporal.io/server/common/persistence/visibility/manager" "go.temporal.io/server/common/searchattribute" ) func NewQueryConverter( pluginName string, - request *manager.ListWorkflowExecutionsRequestV2, + namespaceName namespace.Name, + namespaceID namespace.ID, saTypeMap searchattribute.NameTypeMap, saMapper searchattribute.Mapper, + queryString string, ) *QueryConverter { switch pluginName { case mysql.PluginNameV8: - return newMySQLQueryConverter(request, saTypeMap, saMapper) + return newMySQLQueryConverter(namespaceName, namespaceID, saTypeMap, saMapper, queryString) case postgresql.PluginNameV12: - return newPostgreSQLQueryConverter(request, saTypeMap, saMapper) + return newPostgreSQLQueryConverter(namespaceName, namespaceID, saTypeMap, saMapper, queryString) case sqlite.PluginName: - return newSqliteQueryConverter(request, saTypeMap, saMapper) + return newSqliteQueryConverter(namespaceName, namespaceID, saTypeMap, saMapper, queryString) default: return nil } diff --git a/common/persistence/visibility/store/sql/query_converter_mysql.go b/common/persistence/visibility/store/sql/query_converter_mysql.go index fd4785e9407..19105cc0b87 100644 --- a/common/persistence/visibility/store/sql/query_converter_mysql.go +++ b/common/persistence/visibility/store/sql/query_converter_mysql.go @@ -32,7 +32,6 @@ import ( "github.com/xwb1989/sqlparser" "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence/sql/sqlplugin" - "go.temporal.io/server/common/persistence/visibility/manager" "go.temporal.io/server/common/persistence/visibility/store/query" "go.temporal.io/server/common/searchattribute" ) @@ -83,15 +82,19 @@ func (node *jsonOverlapsExpr) Format(buf *sqlparser.TrackedBuffer) { } func newMySQLQueryConverter( - request *manager.ListWorkflowExecutionsRequestV2, + namespaceName namespace.Name, + namespaceID namespace.ID, saTypeMap searchattribute.NameTypeMap, saMapper searchattribute.Mapper, + queryString string, ) *QueryConverter { return newQueryConverterInternal( &mysqlQueryConverter{}, - request, + namespaceName, + namespaceID, saTypeMap, saMapper, + queryString, ) } @@ -253,3 +256,32 @@ func (c *mysqlQueryConverter) buildSelectStmt( searchattribute.GetSqlDbColName(searchattribute.RunID), ), queryArgs } + +func (c *mysqlQueryConverter) buildCountStmt( + namespaceID namespace.ID, + queryString string, +) (string, []any) { + var whereClauses []string + var queryArgs []any + + whereClauses = append( + whereClauses, + fmt.Sprintf("(%s = ?)", searchattribute.GetSqlDbColName(searchattribute.NamespaceID)), + ) + queryArgs = append(queryArgs, namespaceID.String()) + + if len(queryString) > 0 { + whereClauses = append(whereClauses, fmt.Sprintf("(%s)", queryString)) + } + + return fmt.Sprintf( + `SELECT COUNT(1) + FROM executions_visibility ev + LEFT JOIN custom_search_attributes + USING (%s, %s) + WHERE %s`, + searchattribute.GetSqlDbColName(searchattribute.NamespaceID), + searchattribute.GetSqlDbColName(searchattribute.RunID), + strings.Join(whereClauses, " AND "), + ), queryArgs +} diff --git a/common/persistence/visibility/store/sql/query_converter_postgresql.go b/common/persistence/visibility/store/sql/query_converter_postgresql.go index 2420a40e992..050e25cae4b 100644 --- a/common/persistence/visibility/store/sql/query_converter_postgresql.go +++ b/common/persistence/visibility/store/sql/query_converter_postgresql.go @@ -31,7 +31,6 @@ import ( "github.com/xwb1989/sqlparser" "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence/sql/sqlplugin" - "go.temporal.io/server/common/persistence/visibility/manager" "go.temporal.io/server/common/persistence/visibility/store/query" "go.temporal.io/server/common/searchattribute" ) @@ -64,15 +63,19 @@ func (node *pgCastExpr) Format(buf *sqlparser.TrackedBuffer) { } func newPostgreSQLQueryConverter( - request *manager.ListWorkflowExecutionsRequestV2, + namespaceName namespace.Name, + namespaceID namespace.ID, saTypeMap searchattribute.NameTypeMap, saMapper searchattribute.Mapper, + queryString string, ) *QueryConverter { return newQueryConverterInternal( &pgQueryConverter{}, - request, + namespaceName, + namespaceID, saTypeMap, saMapper, + queryString, ) } @@ -256,3 +259,26 @@ func (c *pgQueryConverter) buildSelectStmt( searchattribute.GetSqlDbColName(searchattribute.RunID), ), queryArgs } + +func (c *pgQueryConverter) buildCountStmt( + namespaceID namespace.ID, + queryString string, +) (string, []any) { + var whereClauses []string + var queryArgs []any + + whereClauses = append( + whereClauses, + fmt.Sprintf("(%s = ?)", searchattribute.GetSqlDbColName(searchattribute.NamespaceID)), + ) + queryArgs = append(queryArgs, namespaceID.String()) + + if len(queryString) > 0 { + whereClauses = append(whereClauses, fmt.Sprintf("(%s)", queryString)) + } + + return fmt.Sprintf( + "SELECT COUNT(1) FROM executions_visibility WHERE %s", + strings.Join(whereClauses, " AND "), + ), queryArgs +} diff --git a/common/persistence/visibility/store/sql/query_converter_sqlite.go b/common/persistence/visibility/store/sql/query_converter_sqlite.go index 9e6025b625a..821f53c1b90 100644 --- a/common/persistence/visibility/store/sql/query_converter_sqlite.go +++ b/common/persistence/visibility/store/sql/query_converter_sqlite.go @@ -31,7 +31,6 @@ import ( "github.com/xwb1989/sqlparser" "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence/sql/sqlplugin" - "go.temporal.io/server/common/persistence/visibility/manager" "go.temporal.io/server/common/persistence/visibility/store/query" "go.temporal.io/server/common/searchattribute" ) @@ -48,15 +47,19 @@ const ( ) func newSqliteQueryConverter( - request *manager.ListWorkflowExecutionsRequestV2, + namespaceName namespace.Name, + namespaceID namespace.ID, saTypeMap searchattribute.NameTypeMap, saMapper searchattribute.Mapper, + queryString string, ) *QueryConverter { return newQueryConverterInternal( &sqliteQueryConverter{}, - request, + namespaceName, + namespaceID, saTypeMap, saMapper, + queryString, ) } @@ -307,3 +310,26 @@ func (c *sqliteQueryConverter) buildFtsSelectStmt( ), } } + +func (c *sqliteQueryConverter) buildCountStmt( + namespaceID namespace.ID, + queryString string, +) (string, []any) { + var whereClauses []string + var queryArgs []any + + whereClauses = append( + whereClauses, + fmt.Sprintf("(%s = ?)", searchattribute.GetSqlDbColName(searchattribute.NamespaceID)), + ) + queryArgs = append(queryArgs, namespaceID.String()) + + if len(queryString) > 0 { + whereClauses = append(whereClauses, fmt.Sprintf("(%s)", queryString)) + } + + return fmt.Sprintf( + "SELECT COUNT(1) FROM executions_visibility WHERE %s", + strings.Join(whereClauses, " AND "), + ), queryArgs +} diff --git a/common/persistence/visibility/store/sql/visibility_store.go b/common/persistence/visibility/store/sql/visibility_store.go index bcbc859d809..84e4846a51d 100644 --- a/common/persistence/visibility/store/sql/visibility_store.go +++ b/common/persistence/visibility/store/sql/visibility_store.go @@ -364,8 +364,15 @@ func (s *VisibilityStore) ListWorkflowExecutions( return nil, err } - converter := NewQueryConverter(s.GetName(), request, saTypeMap, saMapper) - selectFilter, err := converter.BuildSelectStmt() + converter := NewQueryConverter( + s.GetName(), + request.Namespace, + request.NamespaceID, + saTypeMap, + saMapper, + request.Query, + ) + selectFilter, err := converter.BuildSelectStmt(request.PageSize, request.NextPageToken) if err != nil { return nil, err } @@ -417,10 +424,39 @@ func (s *VisibilityStore) ScanWorkflowExecutions( } func (s *VisibilityStore) CountWorkflowExecutions( - _ context.Context, - _ *manager.CountWorkflowExecutionsRequest, + ctx context.Context, + request *manager.CountWorkflowExecutionsRequest, ) (*manager.CountWorkflowExecutionsResponse, error) { - return nil, store.OperationNotSupportedErr + saTypeMap, err := s.searchAttributesProvider.GetSearchAttributes(s.GetIndexName(), false) + if err != nil { + return nil, err + } + + saMapper, err := s.searchAttributesMapperProvider.GetMapper(request.Namespace) + if err != nil { + return nil, err + } + + converter := NewQueryConverter( + s.GetName(), + request.Namespace, + request.NamespaceID, + saTypeMap, + saMapper, + request.Query, + ) + selectFilter, err := converter.BuildCountStmt() + if err != nil { + return nil, err + } + + count, err := s.sqlStore.Db.CountFromVisibility(ctx, *selectFilter) + if err != nil { + return nil, serviceerror.NewUnavailable( + fmt.Sprintf("CountWorkflowExecutions operation failed. Query failed: %v", err)) + } + + return &manager.CountWorkflowExecutionsResponse{Count: count}, nil } func (s *VisibilityStore) GetWorkflowExecution( diff --git a/tests/advanced_visibility_test.go b/tests/advanced_visibility_test.go index fa5b68cabc0..9ed11791fd4 100644 --- a/tests/advanced_visibility_test.go +++ b/tests/advanced_visibility_test.go @@ -1081,10 +1081,6 @@ func (s *advancedVisibilitySuite) TestScanWorkflow_PageToken() { } func (s *advancedVisibilitySuite) TestCountWorkflow() { - if !s.isElasticsearchEnabled { - s.T().Skip("This test is only for Elasticsearch") - } - id := "es-integration-count-workflow-test" wt := "es-integration-count-workflow-test-type" tl := "es-integration-count-workflow-test-taskqueue"