diff --git a/common/persistence/sql/sqlplugin/interfaces.go b/common/persistence/sql/sqlplugin/interfaces.go index a450bff837f..303e5916297 100644 --- a/common/persistence/sql/sqlplugin/interfaces.go +++ b/common/persistence/sql/sqlplugin/interfaces.go @@ -28,6 +28,7 @@ import ( "context" "database/sql" + "github.com/jmoiron/sqlx" "go.temporal.io/server/common/config" "go.temporal.io/server/common/resolver" ) @@ -131,5 +132,6 @@ type ( NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error + PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error) } ) diff --git a/common/persistence/sql/sqlplugin/mysql/typeconv.go b/common/persistence/sql/sqlplugin/mysql/typeconv.go index 59b261804cc..bb5dba9b2d3 100644 --- a/common/persistence/sql/sqlplugin/mysql/typeconv.go +++ b/common/persistence/sql/sqlplugin/mysql/typeconv.go @@ -45,7 +45,7 @@ func (c *converter) ToMySQLDateTime(t time.Time) time.Time { if t.IsZero() { return minMySQLDateTime } - return t.UTC() + return t.UTC().Truncate(time.Microsecond) } // FromMySQLDateTime converts mysql datetime and returns go time diff --git a/common/persistence/sql/sqlplugin/mysql/visibility_v8.go b/common/persistence/sql/sqlplugin/mysql/visibility_v8.go new file mode 100644 index 00000000000..15987299f59 --- /dev/null +++ b/common/persistence/sql/sqlplugin/mysql/visibility_v8.go @@ -0,0 +1,281 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package mysql + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + + "go.temporal.io/server/common/persistence/sql/sqlplugin" +) + +var ( + templateInsertWorkflowExecution = fmt.Sprintf( + `INSERT INTO executions_visibility (%s) + VALUES (%s) + ON DUPLICATE KEY UPDATE run_id = VALUES(run_id)`, + strings.Join(sqlplugin.DbFields, ", "), + sqlplugin.BuildNamedPlaceholder(sqlplugin.DbFields...), + ) + + templateInsertCustomSearchAttributes = ` + INSERT INTO custom_search_attributes ( + namespace_id, run_id, search_attributes + ) VALUES (:namespace_id, :run_id, :search_attributes) + ON DUPLICATE KEY UPDATE run_id = VALUES(run_id)` + + templateUpsertWorkflowExecution = fmt.Sprintf( + `INSERT INTO executions_visibility (%s) + VALUES (%s) + %s`, + strings.Join(sqlplugin.DbFields, ", "), + sqlplugin.BuildNamedPlaceholder(sqlplugin.DbFields...), + buildOnDuplicateKeyUpdate(sqlplugin.DbFields...), + ) + + templateUpsertCustomSearchAttributes = ` + INSERT INTO custom_search_attributes ( + namespace_id, run_id, search_attributes + ) VALUES (:namespace_id, :run_id, :search_attributes) + ON DUPLICATE KEY UPDATE search_attributes = VALUES(search_attributes)` + + templateDeleteWorkflowExecution_v8 = ` + DELETE FROM executions_visibility + WHERE namespace_id = :namespace_id AND run_id = :run_id` + + templateDeleteCustomSearchAttributes = ` + DELETE FROM custom_search_attributes + WHERE namespace_id = :namespace_id AND run_id = :run_id` + + templateGetWorkflowExecution_v8 = fmt.Sprintf( + `SELECT %s FROM executions_visibility + WHERE namespace_id = :namespace_id AND run_id = :run_id`, + strings.Join(sqlplugin.DbFields, ", "), + ) +) + +func buildOnDuplicateKeyUpdate(fields ...string) string { + items := make([]string, len(fields)) + for i, field := range fields { + items[i] = fmt.Sprintf("%s = VALUES(%s)", field, field) + } + return fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", strings.Join(items, ", ")) +} + +// InsertIntoVisibility inserts a row into visibility table. If an row already exist, +// its left as such and no update will be made +func (mdb *dbV8) InsertIntoVisibility( + ctx context.Context, + row *sqlplugin.VisibilityRow, +) (result sql.Result, retError error) { + finalRow := mdb.prepareRowForDB(row) + tx, err := mdb.db.db.BeginTxx(ctx, nil) + if err != nil { + return nil, err + } + defer func() { + err := tx.Rollback() + // If the error is sql.ErrTxDone, it means the transaction already closed, so ignore error. + if err != nil && !errors.Is(err, sql.ErrTxDone) { + // Transaction rollback error should never happen, unless db connection was lost. + retError = fmt.Errorf("transaction rollback failed: %w", retError) + } + }() + result, err = tx.NamedExecContext(ctx, templateInsertWorkflowExecution, finalRow) + if err != nil { + return nil, err + } + _, err = tx.NamedExecContext(ctx, templateInsertCustomSearchAttributes, finalRow) + if err != nil { + return nil, err + } + err = tx.Commit() + if err != nil { + return nil, err + } + return result, nil +} + +// ReplaceIntoVisibility replaces an existing row if it exist or creates a new row in visibility table +func (mdb *dbV8) ReplaceIntoVisibility( + ctx context.Context, + row *sqlplugin.VisibilityRow, +) (result sql.Result, retError error) { + finalRow := mdb.prepareRowForDB(row) + tx, err := mdb.db.db.BeginTxx(ctx, nil) + if err != nil { + return nil, err + } + defer func() { + err := tx.Rollback() + // If the error is sql.ErrTxDone, it means the transaction already closed, so ignore error. + if err != nil && !errors.Is(err, sql.ErrTxDone) { + // Transaction rollback error should never happen, unless db connection was lost. + retError = fmt.Errorf("transaction rollback failed: %w", retError) + } + }() + result, err = tx.NamedExecContext(ctx, templateUpsertWorkflowExecution, finalRow) + if err != nil { + return nil, err + } + _, err = tx.NamedExecContext(ctx, templateUpsertCustomSearchAttributes, finalRow) + if err != nil { + return nil, err + } + err = tx.Commit() + if err != nil { + return nil, err + } + return result, nil +} + +// DeleteFromVisibility deletes a row from visibility table if it exist +func (mdb *dbV8) DeleteFromVisibility( + ctx context.Context, + filter sqlplugin.VisibilityDeleteFilter, +) (result sql.Result, retError error) { + tx, err := mdb.db.db.BeginTxx(ctx, nil) + if err != nil { + return nil, err + } + defer func() { + err := tx.Rollback() + // If the error is sql.ErrTxDone, it means the transaction already closed, so ignore error. + if err != nil && !errors.Is(err, sql.ErrTxDone) { + // Transaction rollback error should never happen, unless db connection was lost. + retError = fmt.Errorf("transaction rollback failed: %w", retError) + } + }() + _, err = mdb.conn.NamedExecContext(ctx, templateDeleteCustomSearchAttributes, filter) + if err != nil { + return nil, err + } + result, err = mdb.conn.NamedExecContext(ctx, templateDeleteWorkflowExecution_v8, filter) + if err != nil { + return nil, err + } + err = tx.Commit() + if err != nil { + return nil, err + } + return result, nil +} + +// SelectFromVisibility reads one or more rows from visibility table +func (mdb *dbV8) SelectFromVisibility( + ctx context.Context, + filter sqlplugin.VisibilitySelectFilter, +) ([]sqlplugin.VisibilityRow, error) { + if len(filter.Query) == 0 { + // backward compatibility for existing tests + err := sqlplugin.GenerateSelectQuery(&filter, mdb.converter.ToMySQLDateTime) + if err != nil { + return nil, err + } + } + + var rows []sqlplugin.VisibilityRow + err := mdb.conn.SelectContext(ctx, &rows, filter.Query, filter.QueryArgs...) + if err != nil { + return nil, err + } + for i := range rows { + err = mdb.processRowFromDB(&rows[i]) + if err != nil { + return nil, err + } + } + return rows, nil +} + +// GetFromVisibility reads one row from visibility table +func (mdb *dbV8) GetFromVisibility( + ctx context.Context, + filter sqlplugin.VisibilityGetFilter, +) (*sqlplugin.VisibilityRow, error) { + var row sqlplugin.VisibilityRow + stmt, err := mdb.conn.PrepareNamedContext(ctx, templateGetWorkflowExecution_v8) + if err != nil { + return nil, err + } + err = stmt.GetContext(ctx, &row, filter) + if err != nil { + return nil, err + } + err = mdb.processRowFromDB(&row) + if err != nil { + return nil, err + } + return &row, nil +} + +func (mdb *dbV8) prepareRowForDB(row *sqlplugin.VisibilityRow) *sqlplugin.VisibilityRow { + if row == nil { + return nil + } + finalRow := *row + finalRow.StartTime = mdb.converter.ToMySQLDateTime(finalRow.StartTime) + finalRow.ExecutionTime = mdb.converter.ToMySQLDateTime(finalRow.ExecutionTime) + if finalRow.CloseTime != nil { + *finalRow.CloseTime = mdb.converter.ToMySQLDateTime(*finalRow.CloseTime) + } + return &finalRow +} + +func (mdb *dbV8) processRowFromDB(row *sqlplugin.VisibilityRow) error { + if row == nil { + return nil + } + row.StartTime = mdb.converter.FromMySQLDateTime(row.StartTime) + row.ExecutionTime = mdb.converter.FromMySQLDateTime(row.ExecutionTime) + if row.CloseTime != nil { + closeTime := mdb.converter.FromMySQLDateTime(*row.CloseTime) + row.CloseTime = &closeTime + } + if row.SearchAttributes != nil { + for saName, saValue := range *row.SearchAttributes { + switch typedSaValue := saValue.(type) { + case []interface{}: + // the only valid type is slice of strings + strSlice := make([]string, len(typedSaValue)) + for i, item := range typedSaValue { + switch v := item.(type) { + case string: + strSlice[i] = v + default: + return fmt.Errorf("Unexpected data type in keyword list: %T (expected string)", v) + } + } + (*row.SearchAttributes)[saName] = strSlice + default: + // no-op + } + } + } + return nil +} diff --git a/common/persistence/sql/sqlplugin/tests/visibility.go b/common/persistence/sql/sqlplugin/tests/visibility.go index 39fd48f2b77..164f3376a2a 100644 --- a/common/persistence/sql/sqlplugin/tests/visibility.go +++ b/common/persistence/sql/sqlplugin/tests/visibility.go @@ -289,7 +289,7 @@ func (s *visibilitySuite) TestReplaceSelect_Exists() { s.Equal([]sqlplugin.VisibilityRow{visibility}, rows) } -func (s *visibilitySuite) TestDeleteSelect() { +func (s *visibilitySuite) TestDeleteGet() { namespaceID := primitives.NewUUID() runID := primitives.NewUUID() @@ -303,15 +303,15 @@ func (s *visibilitySuite) TestDeleteSelect() { s.NoError(err) s.Equal(0, int(rowsAffected)) - selectFilter := sqlplugin.VisibilitySelectFilter{ + getFilter := sqlplugin.VisibilityGetFilter{ NamespaceID: namespaceID.String(), - RunID: convert.StringPtr(runID.String()), + RunID: runID.String(), } - _, err = s.store.SelectFromVisibility(newVisibilityContext(), selectFilter) + _, err = s.store.GetFromVisibility(newVisibilityContext(), getFilter) s.Error(err) // TODO persistence layer should do proper error translation } -func (s *visibilitySuite) TestInsertDeleteSelect() { +func (s *visibilitySuite) TestInsertDeleteGet() { namespaceID := primitives.NewUUID() runID := primitives.NewUUID() workflowTypeName := shuffle.String(testVisibilityWorkflowTypeName) @@ -349,15 +349,15 @@ func (s *visibilitySuite) TestInsertDeleteSelect() { s.NoError(err) s.Equal(1, int(rowsAffected)) - selectFilter := sqlplugin.VisibilitySelectFilter{ + getFilter := sqlplugin.VisibilityGetFilter{ NamespaceID: namespaceID.String(), - RunID: convert.StringPtr(runID.String()), + RunID: runID.String(), } - _, err = s.store.SelectFromVisibility(newVisibilityContext(), selectFilter) + _, err = s.store.GetFromVisibility(newVisibilityContext(), getFilter) s.Error(err) // TODO persistence layer should do proper error translation } -func (s *visibilitySuite) TestReplaceDeleteSelect() { +func (s *visibilitySuite) TestReplaceDeleteGet() { namespaceID := primitives.NewUUID() runID := primitives.NewUUID() workflowTypeName := shuffle.String(testVisibilityWorkflowTypeName) @@ -395,11 +395,11 @@ func (s *visibilitySuite) TestReplaceDeleteSelect() { s.NoError(err) s.Equal(1, int(rowsAffected)) - selectFilter := sqlplugin.VisibilitySelectFilter{ + getFilter := sqlplugin.VisibilityGetFilter{ NamespaceID: namespaceID.String(), - RunID: convert.StringPtr(runID.String()), + RunID: runID.String(), } - _, err = s.store.SelectFromVisibility(newVisibilityContext(), selectFilter) + _, err = s.store.GetFromVisibility(newVisibilityContext(), getFilter) s.Error(err) // TODO persistence layer should do proper error translation } diff --git a/common/persistence/sql/sqlplugin/util.go b/common/persistence/sql/sqlplugin/util.go new file mode 100644 index 00000000000..7146e978a15 --- /dev/null +++ b/common/persistence/sql/sqlplugin/util.go @@ -0,0 +1,39 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package sqlplugin + +import "strings" + +func appendPrefix(prefix string, fields []string) []string { + out := make([]string, len(fields)) + for i, field := range fields { + out[i] = prefix + field + } + return out +} + +func BuildNamedPlaceholder(fields ...string) string { + return strings.Join(appendPrefix(":", fields), ", ") +} diff --git a/common/persistence/sql/sqlplugin/visibility.go b/common/persistence/sql/sqlplugin/visibility.go index 86d89594930..f46973d1b4d 100644 --- a/common/persistence/sql/sqlplugin/visibility.go +++ b/common/persistence/sql/sqlplugin/visibility.go @@ -27,10 +27,24 @@ package sqlplugin import ( "context" "database/sql" + "database/sql/driver" + "encoding/json" + "fmt" + "reflect" + "strings" "time" + + "github.com/iancoleman/strcase" + + enumspb "go.temporal.io/api/enums/v1" + "go.temporal.io/server/common/searchattribute" ) type ( + // VisibilitySearchAttributes represents the search attributes json + // in executions_visibility table + VisibilitySearchAttributes map[string]interface{} + // VisibilityRow represents a row in executions_visibility table VisibilityRow struct { NamespaceID string @@ -45,6 +59,7 @@ type ( Memo []byte Encoding string TaskQueue string + SearchAttributes *VisibilitySearchAttributes } // VisibilitySelectFilter contains the column names within executions_visibility table that @@ -58,6 +73,9 @@ type ( MinTime *time.Time MaxTime *time.Time PageSize *int + + Query string + QueryArgs []interface{} } VisibilityGetFilter struct { @@ -89,3 +107,146 @@ type ( DeleteFromVisibility(ctx context.Context, filter VisibilityDeleteFilter) (sql.Result, error) } ) + +var _ sql.Scanner = (*VisibilitySearchAttributes)(nil) +var _ driver.Valuer = (*VisibilitySearchAttributes)(nil) + +var DbFields = getDbFields() + +func (vsa *VisibilitySearchAttributes) Scan(src interface{}) error { + if src == nil { + return nil + } + switch v := src.(type) { + case []byte: + return json.Unmarshal(v, &vsa) + case string: + return json.Unmarshal([]byte(v), &vsa) + default: + return fmt.Errorf("unsupported type for VisibilitySearchAttributes: %T", v) + } +} + +func (vsa VisibilitySearchAttributes) Value() (driver.Value, error) { + if vsa == nil { + return nil, nil + } + return json.Marshal(vsa) +} + +func getDbFields() []string { + t := reflect.TypeOf(VisibilityRow{}) + dbFields := make([]string, t.NumField()) + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + dbFields[i] = f.Tag.Get("db") + if dbFields[i] == "" { + dbFields[i] = strcase.ToSnake(f.Name) + } + } + return dbFields +} + +// TODO (rodrigozhou): deprecate with standard visibility code. +// GenerateSelectQuery generates the SELECT query based on the fields of VisibilitySelectFilter +// for backward compatibility of any use case using old format (eg: unit test). +// It will be removed after all use cases change to use query converter. +func GenerateSelectQuery( + filter *VisibilitySelectFilter, + convertToDbDateTime func(time.Time) time.Time, +) error { + whereClauses := make([]string, 0, 10) + queryArgs := make([]interface{}, 0, 10) + + whereClauses = append( + whereClauses, + fmt.Sprintf("%s = ?", searchattribute.GetSqlDbColName(searchattribute.NamespaceID)), + ) + queryArgs = append(queryArgs, filter.NamespaceID) + + if filter.WorkflowID != nil { + whereClauses = append( + whereClauses, + fmt.Sprintf("%s = ?", searchattribute.GetSqlDbColName(searchattribute.WorkflowID)), + ) + queryArgs = append(queryArgs, *filter.WorkflowID) + } + + if filter.WorkflowTypeName != nil { + whereClauses = append( + whereClauses, + fmt.Sprintf("%s = ?", searchattribute.GetSqlDbColName(searchattribute.WorkflowType)), + ) + queryArgs = append(queryArgs, *filter.WorkflowTypeName) + } + + timeAttr := searchattribute.StartTime + if filter.Status != int32(enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING) { + timeAttr = searchattribute.CloseTime + } + if filter.Status == int32(enumspb.WORKFLOW_EXECUTION_STATUS_UNSPECIFIED) { + whereClauses = append( + whereClauses, + fmt.Sprintf("%s != ?", searchattribute.GetSqlDbColName(searchattribute.ExecutionStatus)), + ) + queryArgs = append(queryArgs, int32(enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING)) + } else { + whereClauses = append( + whereClauses, + fmt.Sprintf("%s = ?", searchattribute.GetSqlDbColName(searchattribute.ExecutionStatus)), + ) + queryArgs = append(queryArgs, filter.Status) + } + + switch { + case filter.RunID != nil && filter.MinTime == nil && filter.Status != 1: + whereClauses = append( + whereClauses, + fmt.Sprintf("%s = ?", searchattribute.GetSqlDbColName(searchattribute.RunID)), + ) + queryArgs = append( + queryArgs, + *filter.RunID, + 1, // page size arg + ) + case filter.RunID != nil && filter.MinTime != nil && filter.MaxTime != nil && filter.PageSize != nil: + // pagination filters + *filter.MinTime = convertToDbDateTime(*filter.MinTime) + *filter.MaxTime = convertToDbDateTime(*filter.MaxTime) + whereClauses = append( + whereClauses, + fmt.Sprintf("%s >= ?", searchattribute.GetSqlDbColName(timeAttr)), + fmt.Sprintf("%s <= ?", searchattribute.GetSqlDbColName(timeAttr)), + fmt.Sprintf( + "((%s = ? AND %s > ?) OR %s < ?)", + searchattribute.GetSqlDbColName(timeAttr), + searchattribute.GetSqlDbColName(searchattribute.RunID), + searchattribute.GetSqlDbColName(timeAttr), + ), + ) + queryArgs = append( + queryArgs, + *filter.MinTime, + *filter.MaxTime, + *filter.MaxTime, + *filter.RunID, + *filter.MaxTime, + *filter.PageSize, + ) + default: + return fmt.Errorf("invalid query filter") + } + + filter.Query = fmt.Sprintf( + `SELECT %s FROM executions_visibility + WHERE %s + ORDER BY %s DESC, %s + LIMIT ?`, + strings.Join(DbFields, ", "), + strings.Join(whereClauses, " AND "), + searchattribute.GetSqlDbColName(timeAttr), + searchattribute.GetSqlDbColName(searchattribute.RunID), + ) + filter.QueryArgs = queryArgs + return nil +} diff --git a/common/persistence/tests/visibility_persistence_suite_test.go b/common/persistence/tests/visibility_persistence_suite_test.go index 9f98eaad2b2..25a340a2662 100644 --- a/common/persistence/tests/visibility_persistence_suite_test.go +++ b/common/persistence/tests/visibility_persistence_suite_test.go @@ -29,6 +29,7 @@ import ( "fmt" "time" + "github.com/golang/mock/gomock" "github.com/pborman/uuid" "github.com/stretchr/testify/require" commonpb "go.temporal.io/api/common/v1" @@ -56,9 +57,12 @@ type ( // override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test, // not merely log an error *require.Assertions + controller *gomock.Controller persistencetests.TestBase - VisibilityMgr manager.VisibilityManager + VisibilityMgr manager.VisibilityManager + SearchAttributesProvider searchattribute.Provider + SearchAttributesMapperProvider searchattribute.MapperProvider ctx context.Context cancel context.CancelFunc @@ -71,13 +75,19 @@ func (s *VisibilityPersistenceSuite) SetupSuite() { cfg := s.DefaultTestCluster.Config() var err error + s.controller = gomock.NewController(s.T()) + s.SearchAttributesProvider = searchattribute.NewTestProvider() + s.SearchAttributesMapperProvider = searchattribute.NewTestMapperProvider(nil) s.VisibilityMgr, err = visibility.NewStandardManager( cfg, resolver.NewNoopResolver(), + s.SearchAttributesProvider, + s.SearchAttributesMapperProvider, dynamicconfig.GetIntPropertyFn(1000), dynamicconfig.GetIntPropertyFn(1000), metrics.NoopMetricsHandler, - s.Logger) + s.Logger, + ) if err != nil { // s.NoError doesn't work here. diff --git a/common/persistence/visibility/factory.go b/common/persistence/visibility/factory.go index 69da46227f4..52c66dc69f5 100644 --- a/common/persistence/visibility/factory.go +++ b/common/persistence/visibility/factory.go @@ -29,13 +29,15 @@ import ( "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/log" "go.temporal.io/server/common/metrics" + "go.temporal.io/server/common/persistence/sql/sqlplugin/mysql" "go.temporal.io/server/common/persistence/visibility/manager" "go.temporal.io/server/common/persistence/visibility/store" "go.temporal.io/server/common/persistence/visibility/store/elasticsearch" esclient "go.temporal.io/server/common/persistence/visibility/store/elasticsearch/client" + "go.temporal.io/server/common/persistence/visibility/store/sql" "go.temporal.io/server/common/persistence/visibility/store/standard" "go.temporal.io/server/common/persistence/visibility/store/standard/cassandra" - "go.temporal.io/server/common/persistence/visibility/store/standard/sql" + standardSql "go.temporal.io/server/common/persistence/visibility/store/standard/sql" "go.temporal.io/server/common/resolver" "go.temporal.io/server/common/searchattribute" ) @@ -67,6 +69,8 @@ func NewManager( stdVisibilityManager, err := NewStandardManager( persistenceCfg, persistenceResolver, + searchAttributesProvider, + searchAttributesMapperProvider, standardVisibilityPersistenceMaxReadQPS, standardVisibilityPersistenceMaxWriteQPS, metricsHandler, @@ -157,6 +161,8 @@ func NewManager( func NewStandardManager( persistenceCfg config.Persistence, persistenceResolver resolver.ServiceResolver, + searchAttributesProvider searchattribute.Provider, + searchAttributesMapperProvider searchattribute.MapperProvider, standardVisibilityPersistenceMaxReadQPS dynamicconfig.IntPropertyFn, standardVisibilityPersistenceMaxWriteQPS dynamicconfig.IntPropertyFn, @@ -168,6 +174,8 @@ func NewStandardManager( stdVisibilityStore, err := newStandardVisibilityStore( persistenceCfg, persistenceResolver, + searchAttributesProvider, + searchAttributesMapperProvider, logger) if err != nil { return nil, err @@ -252,6 +260,8 @@ func newVisibilityManager( func newStandardVisibilityStore( persistenceCfg config.Persistence, persistenceResolver resolver.ServiceResolver, + searchAttributesProvider searchattribute.Provider, + searchAttributesMapperProvider searchattribute.MapperProvider, logger log.Logger, ) (store.VisibilityStore, error) { // If standard visibility is not configured. @@ -262,26 +272,44 @@ func newStandardVisibilityStore( visibilityStoreCfg := persistenceCfg.DataStores[persistenceCfg.VisibilityStore] var ( - store store.VisibilityStore - err error + visStore store.VisibilityStore + isStandard bool + err error ) switch { case visibilityStoreCfg.Cassandra != nil: - store, err = cassandra.NewVisibilityStore(*visibilityStoreCfg.Cassandra, persistenceResolver, logger) + visStore, err = cassandra.NewVisibilityStore(*visibilityStoreCfg.Cassandra, persistenceResolver, logger) + isStandard = true case visibilityStoreCfg.SQL != nil: - store, err = sql.NewSQLVisibilityStore(*visibilityStoreCfg.SQL, persistenceResolver, logger) + switch visibilityStoreCfg.SQL.PluginName { + case mysql.PluginNameV8: + isStandard = false + visStore, err = sql.NewSQLVisibilityStore( + *visibilityStoreCfg.SQL, + persistenceResolver, + searchAttributesProvider, + searchAttributesMapperProvider, + logger, + ) + default: + isStandard = true + visStore, err = standardSql.NewSQLVisibilityStore(*visibilityStoreCfg.SQL, persistenceResolver, logger) + } } if err != nil { return nil, err } - if store == nil { + if visStore == nil { logger.Fatal("invalid config: one of cassandra or sql params must be specified for visibility store") return nil, nil } - return standard.NewVisibilityStore(store), nil + if isStandard { + return standard.NewVisibilityStore(visStore), nil + } + return visStore, nil } func newAdvancedVisibilityStore( diff --git a/common/persistence/visibility/store/query/converter.go b/common/persistence/visibility/store/query/converter.go index 46dd6f41e90..0b109d3b5b7 100644 --- a/common/persistence/visibility/store/query/converter.go +++ b/common/persistence/visibility/store/query/converter.go @@ -357,11 +357,11 @@ func (r *rangeCondConverter) Convert(expr sqlparser.Expr) (elastic.Query, error) return nil, wrapConverterError("unable to convert left part of 'between' expression", err) } - fromValue, err := parseSqlValue(sqlparser.String(rangeCond.From)) + fromValue, err := ParseSqlValue(sqlparser.String(rangeCond.From)) if err != nil { return nil, err } - toValue, err := parseSqlValue(sqlparser.String(rangeCond.To)) + toValue, err := ParseSqlValue(sqlparser.String(rangeCond.To)) if err != nil { return nil, err } @@ -478,7 +478,7 @@ func (c *comparisonExprConverter) Convert(expr sqlparser.Expr) (elastic.Query, e func convertComparisonExprValue(expr sqlparser.Expr) (interface{}, error) { switch e := expr.(type) { case *sqlparser.SQLVal: - v, err := parseSqlValue(sqlparser.String(e)) + v, err := ParseSqlValue(sqlparser.String(e)) if err != nil { return nil, err } @@ -520,7 +520,7 @@ func (n *notSupportedExprConverter) Convert(expr sqlparser.Expr) (elastic.Query, return nil, NewConverterError("%s: expression of type %T", NotSupportedErrMessage, expr) } -func parseSqlValue(sqlValue string) (interface{}, error) { +func ParseSqlValue(sqlValue string) (interface{}, error) { if sqlValue == "" { return "", nil } diff --git a/common/persistence/visibility/store/sql/pagination_token.go b/common/persistence/visibility/store/sql/pagination_token.go new file mode 100644 index 00000000000..9a5e3ba8739 --- /dev/null +++ b/common/persistence/visibility/store/sql/pagination_token.go @@ -0,0 +1,52 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package sql + +import ( + "encoding/json" + "time" +) + +type ( + pageToken struct { + CloseTime time.Time + StartTime time.Time + RunID string + } +) + +func deserializePageToken(data []byte) (*pageToken, error) { + if data == nil { + return nil, nil + } + var token *pageToken + err := json.Unmarshal(data, &token) + return token, err +} + +func serializePageToken(token *pageToken) ([]byte, error) { + data, err := json.Marshal(token) + return data, err +} diff --git a/common/persistence/visibility/store/sql/query_converter.go b/common/persistence/visibility/store/sql/query_converter.go new file mode 100644 index 00000000000..4446bdd7d13 --- /dev/null +++ b/common/persistence/visibility/store/sql/query_converter.go @@ -0,0 +1,543 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package sql + +import ( + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/xwb1989/sqlparser" + + enumspb "go.temporal.io/api/enums/v1" + "go.temporal.io/server/common/namespace" + "go.temporal.io/server/common/persistence/sql/sqlplugin" + "go.temporal.io/server/common/persistence/visibility/manager" + "go.temporal.io/server/common/persistence/visibility/store/query" + "go.temporal.io/server/common/searchattribute" +) + +type ( + pluginQueryConverter interface { + convertKeywordListComparisonExpr(expr *sqlparser.ComparisonExpr) (sqlparser.Expr, error) + + convertTextComparisonExpr(expr *sqlparser.ComparisonExpr) (sqlparser.Expr, error) + + buildSelectStmt( + namespaceID namespace.ID, + queryString string, + pageSize int, + token *pageToken, + ) (string, []any) + + getDatetimeFormat() string + } + + QueryConverter struct { + pluginQueryConverter + request *manager.ListWorkflowExecutionsRequestV2 + saTypeMap searchattribute.NameTypeMap + saMapper searchattribute.Mapper + } +) + +var ( + // strings.Replacer takes a sequence of old to new replacements + escapeCharMap = []string{ + "'", "''", + `"`, `\\"`, + "\b", "\\b", + "\n", "\\n", + "\r", "\\r", + "\t", "\\t", + "\\", "\\\\", + } + + supportedComparisonOperators = []string{ + sqlparser.EqualStr, + sqlparser.NotEqualStr, + sqlparser.LessThanStr, + sqlparser.GreaterThanStr, + sqlparser.LessEqualStr, + sqlparser.GreaterEqualStr, + sqlparser.InStr, + sqlparser.NotInStr, + } + + supportedKeyworkListOperators = []string{ + sqlparser.EqualStr, + sqlparser.NotEqualStr, + sqlparser.InStr, + sqlparser.NotInStr, + } + + supportedTextOperators = []string{ + sqlparser.EqualStr, + sqlparser.NotEqualStr, + } + + supportedTypesRangeCond = []enumspb.IndexedValueType{ + enumspb.INDEXED_VALUE_TYPE_DATETIME, + enumspb.INDEXED_VALUE_TYPE_DOUBLE, + enumspb.INDEXED_VALUE_TYPE_INT, + enumspb.INDEXED_VALUE_TYPE_KEYWORD, + } +) + +func newQueryConverterInternal( + pqc pluginQueryConverter, + request *manager.ListWorkflowExecutionsRequestV2, + saTypeMap searchattribute.NameTypeMap, + saMapper searchattribute.Mapper, +) *QueryConverter { + return &QueryConverter{ + pluginQueryConverter: pqc, + request: request, + saTypeMap: saTypeMap, + saMapper: saMapper, + } +} + +func (c *QueryConverter) BuildSelectStmt() (*sqlplugin.VisibilitySelectFilter, error) { + token, err := deserializePageToken(c.request.NextPageToken) + if err != nil { + return nil, err + } + queryString, err := c.convertWhereString(c.request.Query) + if err != nil { + return nil, err + } + queryString, queryArgs := c.buildSelectStmt( + c.request.NamespaceID, + queryString, + c.request.PageSize, + token, + ) + return &sqlplugin.VisibilitySelectFilter{Query: queryString, QueryArgs: queryArgs}, nil +} + +func (c *QueryConverter) convertWhereString(queryString string) (string, error) { + where := strings.TrimSpace(queryString) + if where != "" && !strings.HasPrefix(strings.ToLower(where), "order by") { + where = "where " + where + } + // sqlparser can't parse just WHERE clause but instead accepts only valid SQL statement. + sql := "select * from table1 " + where + stmt, err := sqlparser.Parse(sql) + if err != nil { + return "", err + } + + selectStmt, _ := stmt.(*sqlparser.Select) + err = c.convertSelectStmt(selectStmt) + if err != nil { + return "", err + } + + result := "" + if selectStmt.Where != nil { + result = sqlparser.String(selectStmt.Where.Expr) + } + return result, nil +} + +func (c *QueryConverter) convertSelectStmt(sel *sqlparser.Select) error { + if sel.GroupBy != nil { + return query.NewConverterError("%s: 'group by' clause", query.NotSupportedErrMessage) + } + + if sel.OrderBy != nil { + return query.NewConverterError("%s: 'order by' clause", query.NotSupportedErrMessage) + } + + if sel.Limit != nil { + return query.NewConverterError("%s: 'limit' clause", query.NotSupportedErrMessage) + } + + if sel.Where != nil { + return c.convertWhereExpr(&sel.Where.Expr) + } + + return nil +} + +func (c *QueryConverter) convertWhereExpr(expr *sqlparser.Expr) error { + if expr == nil || *expr == nil { + return errors.New("cannot be nil") + } + + switch e := (*expr).(type) { + case *sqlparser.ParenExpr: + return c.convertWhereExpr(&e.Expr) + case *sqlparser.NotExpr: + return c.convertWhereExpr(&e.Expr) + case *sqlparser.AndExpr: + return c.convertAndExpr(expr) + case *sqlparser.OrExpr: + return c.convertOrExpr(expr) + case *sqlparser.ComparisonExpr: + return c.convertComparisonExpr(expr) + case *sqlparser.RangeCond: + return c.convertRangeCond(expr) + case *sqlparser.IsExpr: + return c.convertIsExpr(expr) + case *sqlparser.FuncExpr: + return query.NewConverterError("%s: function expression", query.NotSupportedErrMessage) + case *sqlparser.ColName: + return query.NewConverterError("incomplete expression") + default: + return query.NewConverterError("%s: expression of type %T", query.NotSupportedErrMessage, e) + } +} + +func (c *QueryConverter) convertAndExpr(exprRef *sqlparser.Expr) error { + expr, ok := (*exprRef).(*sqlparser.AndExpr) + if !ok { + return query.NewConverterError("%v is not an 'and' expression", sqlparser.String(*exprRef)) + } + err := c.convertWhereExpr(&expr.Left) + if err != nil { + return err + } + return c.convertWhereExpr(&expr.Right) +} + +func (c *QueryConverter) convertOrExpr(exprRef *sqlparser.Expr) error { + expr, ok := (*exprRef).(*sqlparser.OrExpr) + if !ok { + return query.NewConverterError("%v is not an 'or' expression", sqlparser.String(*exprRef)) + } + err := c.convertWhereExpr(&expr.Left) + if err != nil { + return err + } + return c.convertWhereExpr(&expr.Right) +} + +func (c *QueryConverter) convertComparisonExpr(exprRef *sqlparser.Expr) error { + expr, ok := (*exprRef).(*sqlparser.ComparisonExpr) + if !ok { + return query.NewConverterError("%v is not a comparison expression", sqlparser.String(*exprRef)) + } + + saName, saFieldName, err := c.convertColName(&expr.Left) + if err != nil { + return err + } + saType, err := c.saTypeMap.GetType(saFieldName) + if err != nil { + return query.NewConverterError( + "%s: column name '%s' is not a valid search attribute", + query.InvalidExpressionErrMessage, + saName, + ) + } + + if !isSupportedComparisonOperator(expr.Operator) { + return query.NewConverterError( + "%s: invalid operator '%s'", + query.InvalidExpressionErrMessage, + expr.Operator, + ) + } + + err = c.convertValueExpr(&expr.Right, saName, saType) + if err != nil { + return err + } + switch saType { + case enumspb.INDEXED_VALUE_TYPE_KEYWORD_LIST: + newExpr, err := c.convertKeywordListComparisonExpr(expr) + if err != nil { + return err + } + *exprRef = newExpr + case enumspb.INDEXED_VALUE_TYPE_TEXT: + newExpr, err := c.convertTextComparisonExpr(expr) + if err != nil { + return err + } + *exprRef = newExpr + } + return nil +} + +func (c *QueryConverter) convertRangeCond(exprRef *sqlparser.Expr) error { + expr, ok := (*exprRef).(*sqlparser.RangeCond) + if !ok { + return query.NewConverterError( + "%v is not a range condition expression", + sqlparser.String(*exprRef), + ) + } + saName, saFieldName, err := c.convertColName(&expr.Left) + if err != nil { + return err + } + saType, err := c.saTypeMap.GetType(saFieldName) + if err != nil { + return query.NewConverterError( + "%s: column name '%s' is not a valid search attribute", + query.InvalidExpressionErrMessage, + saName, + ) + } + if !isSupportedTypeRangeCond(saType) { + return query.NewConverterError( + "%s: cannot do range condition on search attribute '%s' of type %s", + query.InvalidExpressionErrMessage, + saName, + saType.String(), + ) + } + err = c.convertValueExpr(&expr.From, saName, saType) + if err != nil { + return err + } + err = c.convertValueExpr(&expr.To, saName, saType) + if err != nil { + return err + } + return nil +} + +func (c *QueryConverter) convertColName( + exprRef *sqlparser.Expr, +) (saAlias string, saFieldName string, retError error) { + expr, ok := (*exprRef).(*sqlparser.ColName) + if !ok { + return "", "", query.NewConverterError( + "%s: must be a column name but was %T", + query.InvalidExpressionErrMessage, + *exprRef, + ) + } + saAlias = strings.ReplaceAll(sqlparser.String(expr), "`", "") + saFieldName = saAlias + if searchattribute.IsMappable(saAlias) { + var err error + saFieldName, err = c.saMapper.GetFieldName(saAlias, c.request.Namespace.String()) + if err != nil { + return "", "", err + } + } + var newExpr sqlparser.Expr = newColName(searchattribute.GetSqlDbColName(saFieldName)) + if saAlias == searchattribute.CloseTime { + newExpr = getCoalesceCloseTimeExpr(c.getDatetimeFormat()) + } + *exprRef = newExpr + return saAlias, saFieldName, nil +} + +func (c *QueryConverter) convertValueExpr( + exprRef *sqlparser.Expr, + saName string, + saType enumspb.IndexedValueType, +) error { + expr := *exprRef + switch e := expr.(type) { + case *sqlparser.SQLVal: + value, err := c.parseSQLVal(e, saName, saType) + if err != nil { + return err + } + switch v := value.(type) { + case string: + // escape strings for safety + replacer := strings.NewReplacer(escapeCharMap...) + *exprRef = newUnsafeSQLString(replacer.Replace(v)) + case int64: + *exprRef = sqlparser.NewIntVal([]byte(strconv.FormatInt(v, 10))) + case float64: + *exprRef = sqlparser.NewFloatVal([]byte(strconv.FormatFloat(v, 'f', -1, 64))) + default: + // this should never happen: query.ParseSqlValue returns one of the types above + panic(fmt.Sprintf("Unexpected value type: %T", v)) + } + return nil + case sqlparser.BoolVal: + // no-op: no validation needed + return nil + case sqlparser.ValTuple: + // This is "in (1,2,3)" case. + for _, subExpr := range e { + err := c.convertValueExpr(&subExpr, saName, saType) + if err != nil { + return err + } + } + return nil + case *sqlparser.GroupConcatExpr: + return query.NewConverterError("%s: 'group_concat'", query.NotSupportedErrMessage) + case *sqlparser.FuncExpr: + return query.NewConverterError("%s: nested func", query.NotSupportedErrMessage) + case *sqlparser.ColName: + return query.NewConverterError( + "%s: column name on the right side of comparison expression (did you forget to quote '%s'?)", + query.NotSupportedErrMessage, + sqlparser.String(expr), + ) + default: + return query.NewConverterError( + "%s: unexpected value type %T", + query.InvalidExpressionErrMessage, + expr, + ) + } +} + +// parseSQLVal handles values for specific search attributes. +// For datetime, converts to UTC. +// For execution status, converts string to enum value. +func (c *QueryConverter) parseSQLVal( + expr *sqlparser.SQLVal, + saName string, + saType enumspb.IndexedValueType, +) (any, error) { + // Using expr.Val instead of sqlparser.String(expr) because the latter escapes chars using MySQL + // conventions which is incompatible with SQLite. + var sqlValue string + switch expr.Type { + case sqlparser.StrVal: + sqlValue = fmt.Sprintf(`'%s'`, expr.Val) + default: + sqlValue = string(expr.Val) + } + value, err := query.ParseSqlValue(sqlValue) + if err != nil { + return nil, err + } + + if saType == enumspb.INDEXED_VALUE_TYPE_DATETIME { + var tm time.Time + switch v := value.(type) { + case int64: + tm = time.Unix(0, v) + case string: + var err error + tm, err = time.Parse(time.RFC3339Nano, v) + if err != nil { + return "", err + } + default: + return "", query.NewConverterError( + "%s: unexpected value type %T for search attribute %s", + query.InvalidExpressionErrMessage, + v, + saName, + ) + } + return tm.UTC().Format(c.getDatetimeFormat()), nil + } + + if saName == searchattribute.ExecutionStatus { + var status int64 + switch v := value.(type) { + case int64: + status = v + case string: + code, ok := enumspb.WorkflowExecutionStatus_value[v] + if !ok { + return nil, query.NewConverterError( + "%s: invalid execution status value '%s'", + query.InvalidExpressionErrMessage, + v, + ) + } + status = int64(code) + default: + return "", query.NewConverterError( + "%s: unexpected value type %T for search attribute %s", + query.InvalidExpressionErrMessage, + v, + saName, + ) + } + return status, nil + } + + return value, nil +} + +func (c *QueryConverter) convertIsExpr(exprRef *sqlparser.Expr) error { + expr, ok := (*exprRef).(*sqlparser.IsExpr) + if !ok { + return query.NewConverterError("%v is not an 'IS' expression", sqlparser.String(*exprRef)) + } + saName, saFieldName, err := c.convertColName(&expr.Expr) + if err != nil { + return err + } + _, err = c.saTypeMap.GetType(saFieldName) + if err != nil { + return query.NewConverterError( + "%s: column name '%s' is not a valid search attribute", + query.InvalidExpressionErrMessage, + saName, + ) + } + switch expr.Operator { + case sqlparser.IsNullStr, sqlparser.IsNotNullStr: + // no-op + default: + return query.NewConverterError( + "%s: 'IS' operator can only be used with 'NULL' or 'NOT NULL'", + query.InvalidExpressionErrMessage, + ) + } + return nil +} + +func isSupportedOperator(supportedOperators []string, operator string) bool { + for _, op := range supportedOperators { + if operator == op { + return true + } + } + return false +} + +func isSupportedComparisonOperator(operator string) bool { + return isSupportedOperator(supportedComparisonOperators, operator) +} + +func isSupportedKeywordListOperator(operator string) bool { + return isSupportedOperator(supportedKeyworkListOperators, operator) +} + +func isSupportedTextOperator(operator string) bool { + return isSupportedOperator(supportedTextOperators, operator) +} + +func isSupportedTypeRangeCond(saType enumspb.IndexedValueType) bool { + for _, tp := range supportedTypesRangeCond { + if saType == tp { + return true + } + } + return false +} diff --git a/common/persistence/visibility/store/sql/query_converter_factory.go b/common/persistence/visibility/store/sql/query_converter_factory.go new file mode 100644 index 00000000000..b7fe609b78c --- /dev/null +++ b/common/persistence/visibility/store/sql/query_converter_factory.go @@ -0,0 +1,45 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package sql + +import ( + "go.temporal.io/server/common/persistence/sql/sqlplugin/mysql" + "go.temporal.io/server/common/persistence/visibility/manager" + "go.temporal.io/server/common/searchattribute" +) + +func NewQueryConverter( + pluginName string, + request *manager.ListWorkflowExecutionsRequestV2, + saTypeMap searchattribute.NameTypeMap, + saMapper searchattribute.Mapper, +) *QueryConverter { + switch pluginName { + case mysql.PluginNameV8: + return newMySQLQueryConverter(request, saTypeMap, saMapper) + default: + return nil + } +} diff --git a/common/persistence/visibility/store/sql/query_converter_mysql.go b/common/persistence/visibility/store/sql/query_converter_mysql.go new file mode 100644 index 00000000000..d278d807b39 --- /dev/null +++ b/common/persistence/visibility/store/sql/query_converter_mysql.go @@ -0,0 +1,239 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package sql + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/xwb1989/sqlparser" + "go.temporal.io/server/common/namespace" + "go.temporal.io/server/common/persistence/sql/sqlplugin" + "go.temporal.io/server/common/persistence/visibility/manager" + "go.temporal.io/server/common/persistence/visibility/store/query" + "go.temporal.io/server/common/searchattribute" +) + +type ( + castExpr struct { + sqlparser.Expr + Value sqlparser.Expr + Type *sqlparser.ConvertType + } + + memberOfExpr struct { + sqlparser.Expr + Value sqlparser.Expr + JSONArr sqlparser.Expr + } + + jsonOverlapsExpr struct { + sqlparser.Expr + JSONDoc1 sqlparser.Expr + JSONDoc2 sqlparser.Expr + } + + mysqlQueryConverter struct{} +) + +var convertTypeJSON = &sqlparser.ConvertType{Type: "json"} + +var _ sqlparser.Expr = (*castExpr)(nil) +var _ sqlparser.Expr = (*memberOfExpr)(nil) +var _ sqlparser.Expr = (*jsonOverlapsExpr)(nil) + +var _ pluginQueryConverter = (*mysqlQueryConverter)(nil) + +func (node *castExpr) Format(buf *sqlparser.TrackedBuffer) { + buf.Myprintf("cast(%v as %v)", node.Value, node.Type) +} + +func (node *memberOfExpr) Format(buf *sqlparser.TrackedBuffer) { + buf.Myprintf("%v member of (%v)", node.Value, node.JSONArr) +} + +func (node *jsonOverlapsExpr) Format(buf *sqlparser.TrackedBuffer) { + buf.Myprintf("json_overlaps(%v, %v)", node.JSONDoc1, node.JSONDoc2) +} + +func newMySQLQueryConverter( + request *manager.ListWorkflowExecutionsRequestV2, + saTypeMap searchattribute.NameTypeMap, + saMapper searchattribute.Mapper, +) *QueryConverter { + return newQueryConverterInternal( + &mysqlQueryConverter{}, + request, + saTypeMap, + saMapper, + ) +} + +func (c *mysqlQueryConverter) getDatetimeFormat() string { + return "2006-01-02 15:04:05.999999" +} + +func (c *mysqlQueryConverter) convertKeywordListComparisonExpr( + expr *sqlparser.ComparisonExpr, +) (sqlparser.Expr, error) { + if !isSupportedKeywordListOperator(expr.Operator) { + return nil, query.NewConverterError("invalid query") + } + + switch expr.Operator { + case sqlparser.EqualStr: + return &memberOfExpr{ + Value: expr.Right, + JSONArr: expr.Left, + }, nil + case sqlparser.NotEqualStr: + return &sqlparser.NotExpr{ + Expr: &memberOfExpr{ + Value: expr.Right, + JSONArr: expr.Left, + }, + }, nil + case sqlparser.InStr: + return c.convertToJsonOverlapsExpr(expr) + case sqlparser.NotInStr: + jsonOverlapsExpr, err := c.convertToJsonOverlapsExpr(expr) + if err != nil { + return nil, err + } + return &sqlparser.NotExpr{Expr: jsonOverlapsExpr}, nil + default: + // this should never happen since isSupportedKeywordListOperator should already fail + return nil, query.NewConverterError("invalid query") + } +} + +func (c *mysqlQueryConverter) convertToJsonOverlapsExpr( + expr *sqlparser.ComparisonExpr, +) (*jsonOverlapsExpr, error) { + valTuple, isValTuple := expr.Right.(sqlparser.ValTuple) + if !isValTuple { + return nil, query.NewConverterError("invalid query") + } + values := make([]any, len(valTuple)) + for i, val := range valTuple { + value, err := query.ParseSqlValue(sqlparser.String(val)) + if err != nil { + return nil, err + } + values[i] = value + } + jsonValue, err := json.Marshal(values) + if err != nil { + return nil, err + } + return &jsonOverlapsExpr{ + JSONDoc1: expr.Left, + JSONDoc2: &castExpr{ + Value: sqlparser.NewStrVal(jsonValue), + Type: convertTypeJSON, + }, + }, nil +} + +func (c *mysqlQueryConverter) convertTextComparisonExpr( + expr *sqlparser.ComparisonExpr, +) (sqlparser.Expr, error) { + if !isSupportedTextOperator(expr.Operator) { + return nil, query.NewConverterError("invalid query") + } + var newExpr sqlparser.Expr = &sqlparser.MatchExpr{ + Columns: []sqlparser.SelectExpr{&sqlparser.AliasedExpr{Expr: expr.Left}}, + Expr: expr.Right, + Option: sqlparser.NaturalLanguageModeStr, + } + if expr.Operator == sqlparser.NotEqualStr { + newExpr = &sqlparser.NotExpr{Expr: newExpr} + } + return newExpr, nil +} + +func (c *mysqlQueryConverter) buildSelectStmt( + namespaceID namespace.ID, + queryString string, + pageSize int, + token *pageToken, +) (string, []any) { + whereClauses := make([]string, 0, 3) + queryArgs := make([]any, 0, 8) + + whereClauses = append( + whereClauses, + fmt.Sprintf("%s = ?", searchattribute.GetSqlDbColName(searchattribute.NamespaceID)), + ) + queryArgs = append(queryArgs, namespaceID.String()) + + if len(queryString) > 0 { + whereClauses = append(whereClauses, queryString) + } + + if token != nil { + whereClauses = append( + whereClauses, + fmt.Sprintf( + "((%s = ? AND %s = ? AND %s > ?) OR (%s = ? AND %s < ?) OR %s < ?)", + sqlparser.String(getCoalesceCloseTimeExpr(c.getDatetimeFormat())), + searchattribute.GetSqlDbColName(searchattribute.StartTime), + searchattribute.GetSqlDbColName(searchattribute.RunID), + sqlparser.String(getCoalesceCloseTimeExpr(c.getDatetimeFormat())), + searchattribute.GetSqlDbColName(searchattribute.StartTime), + sqlparser.String(getCoalesceCloseTimeExpr(c.getDatetimeFormat())), + ), + ) + queryArgs = append( + queryArgs, + token.CloseTime, + token.StartTime, + token.RunID, + token.CloseTime, + token.StartTime, + token.CloseTime, + ) + } + + queryArgs = append(queryArgs, pageSize) + + return fmt.Sprintf( + `SELECT %s + FROM executions_visibility ev + INNER JOIN custom_search_attributes + USING (%s, %s) + WHERE %s + ORDER BY %s DESC, %s DESC, %s + LIMIT ?`, + strings.Join(addPrefix("ev.", sqlplugin.DbFields), ", "), + searchattribute.GetSqlDbColName(searchattribute.NamespaceID), + searchattribute.GetSqlDbColName(searchattribute.RunID), + strings.Join(whereClauses, " AND "), + sqlparser.String(getCoalesceCloseTimeExpr(c.getDatetimeFormat())), + searchattribute.GetSqlDbColName(searchattribute.StartTime), + searchattribute.GetSqlDbColName(searchattribute.RunID), + ), queryArgs +} diff --git a/common/persistence/visibility/store/sql/query_converter_util.go b/common/persistence/visibility/store/sql/query_converter_util.go new file mode 100644 index 00000000000..dce477f1810 --- /dev/null +++ b/common/persistence/visibility/store/sql/query_converter_util.go @@ -0,0 +1,105 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package sql + +import ( + "time" + + "github.com/xwb1989/sqlparser" + "go.temporal.io/server/common/searchattribute" +) + +type ( + // unsafeSQLString don't escape the string value; unlike sqlparser.SQLVal. + // This is used for building string known to be safe. + unsafeSQLString struct { + sqlparser.Expr + Val string + } + + colName struct { + sqlparser.Expr + Name string + } +) + +const ( + coalesceFuncName = "coalesce" +) + +var _ sqlparser.Expr = (*unsafeSQLString)(nil) +var _ sqlparser.Expr = (*colName)(nil) + +var ( + maxDatetimeValue = getMaxDatetimeValue() +) + +func (node *unsafeSQLString) Format(buf *sqlparser.TrackedBuffer) { + buf.Myprintf("'%s'", node.Val) +} + +func (node *colName) Format(buf *sqlparser.TrackedBuffer) { + buf.Myprintf("%s", node.Name) +} + +func newUnsafeSQLString(val string) *unsafeSQLString { + return &unsafeSQLString{Val: val} +} + +func newColName(name string) *colName { + return &colName{Name: name} +} + +func newFuncExpr(name string, exprs ...sqlparser.Expr) *sqlparser.FuncExpr { + args := make([]sqlparser.SelectExpr, len(exprs)) + for i := range exprs { + args[i] = &sqlparser.AliasedExpr{Expr: exprs[i]} + } + return &sqlparser.FuncExpr{ + Name: sqlparser.NewColIdent(name), + Exprs: args, + } +} + +func addPrefix(prefix string, fields []string) []string { + out := make([]string, len(fields)) + for i, field := range fields { + out[i] = prefix + field + } + return out +} + +func getMaxDatetimeValue() time.Time { + t, _ := time.Parse(time.RFC3339, "9999-12-31T23:59:59Z") + return t +} + +func getCoalesceCloseTimeExpr(format string) sqlparser.Expr { + return newFuncExpr( + coalesceFuncName, + newColName(searchattribute.GetSqlDbColName(searchattribute.CloseTime)), + newUnsafeSQLString(maxDatetimeValue.Format(format)), + ) +} diff --git a/common/persistence/visibility/store/sql/visibility_store.go b/common/persistence/visibility/store/sql/visibility_store.go new file mode 100644 index 00000000000..79b7fcb9066 --- /dev/null +++ b/common/persistence/visibility/store/sql/visibility_store.go @@ -0,0 +1,585 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package sql + +import ( + "context" + "fmt" + "time" + + "go.temporal.io/api/common/v1" + enumspb "go.temporal.io/api/enums/v1" + "go.temporal.io/api/serviceerror" + + "go.temporal.io/server/common/config" + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/namespace" + "go.temporal.io/server/common/persistence" + persistencesql "go.temporal.io/server/common/persistence/sql" + "go.temporal.io/server/common/persistence/sql/sqlplugin" + "go.temporal.io/server/common/persistence/visibility/manager" + "go.temporal.io/server/common/persistence/visibility/store" + "go.temporal.io/server/common/resolver" + "go.temporal.io/server/common/searchattribute" +) + +type ( + VisibilityStore struct { + sqlStore persistencesql.SqlStore + searchAttributesProvider searchattribute.Provider + searchAttributesMapperProvider searchattribute.MapperProvider + } +) + +var _ store.VisibilityStore = (*VisibilityStore)(nil) + +var maxTime, _ = time.Parse(time.RFC3339, "9999-12-31T23:59:59Z") + +// NewSQLVisibilityStore creates an instance of VisibilityStore +func NewSQLVisibilityStore( + cfg config.SQL, + r resolver.ServiceResolver, + searchAttributesProvider searchattribute.Provider, + searchAttributesMapperProvider searchattribute.MapperProvider, + logger log.Logger, +) (*VisibilityStore, error) { + refDbConn := persistencesql.NewRefCountedDBConn(sqlplugin.DbKindVisibility, &cfg, r) + db, err := refDbConn.Get() + if err != nil { + return nil, err + } + return &VisibilityStore{ + sqlStore: persistencesql.NewSqlStore(db, logger), + searchAttributesProvider: searchAttributesProvider, + searchAttributesMapperProvider: searchAttributesMapperProvider, + }, nil +} + +func (s *VisibilityStore) Close() { + s.sqlStore.Close() +} + +func (s *VisibilityStore) GetName() string { + return s.sqlStore.GetName() +} + +func (s *VisibilityStore) GetIndexName() string { + return s.sqlStore.GetDbName() +} + +func (s *VisibilityStore) RecordWorkflowExecutionStarted( + ctx context.Context, + request *store.InternalRecordWorkflowExecutionStartedRequest, +) error { + searchAttributes, err := s.prepareSearchAttributesForDb(request.InternalVisibilityRequestBase) + if err != nil { + return err + } + _, err = s.sqlStore.Db.InsertIntoVisibility(ctx, &sqlplugin.VisibilityRow{ + NamespaceID: request.NamespaceID, + WorkflowID: request.WorkflowID, + RunID: request.RunID, + StartTime: request.StartTime, + ExecutionTime: request.ExecutionTime, + WorkflowTypeName: request.WorkflowTypeName, + Status: int32(enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING), + Memo: request.Memo.Data, + Encoding: request.Memo.EncodingType.String(), + TaskQueue: request.TaskQueue, + SearchAttributes: searchAttributes, + }) + + return err +} + +func (s *VisibilityStore) RecordWorkflowExecutionClosed( + ctx context.Context, + request *store.InternalRecordWorkflowExecutionClosedRequest, +) error { + searchAttributes, err := s.prepareSearchAttributesForDb(request.InternalVisibilityRequestBase) + if err != nil { + return err + } + result, err := s.sqlStore.Db.ReplaceIntoVisibility(ctx, &sqlplugin.VisibilityRow{ + NamespaceID: request.NamespaceID, + WorkflowID: request.WorkflowID, + RunID: request.RunID, + StartTime: request.StartTime, + ExecutionTime: request.ExecutionTime, + WorkflowTypeName: request.WorkflowTypeName, + CloseTime: &request.CloseTime, + Status: int32(request.Status), + HistoryLength: &request.HistoryLength, + Memo: request.Memo.Data, + Encoding: request.Memo.EncodingType.String(), + TaskQueue: request.TaskQueue, + SearchAttributes: searchAttributes, + }) + if err != nil { + return err + } + noRowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("RecordWorkflowExecutionClosed rowsAffected error: %v", err) + } + if noRowsAffected > 2 { // either adds a new row or deletes old row and adds new row + return fmt.Errorf( + "RecordWorkflowExecutionClosed unexpected numRows (%v) updated", + noRowsAffected, + ) + } + return nil +} + +func (s *VisibilityStore) UpsertWorkflowExecution( + ctx context.Context, + request *store.InternalUpsertWorkflowExecutionRequest, +) error { + searchAttributes, err := s.prepareSearchAttributesForDb(request.InternalVisibilityRequestBase) + if err != nil { + return err + } + result, err := s.sqlStore.Db.ReplaceIntoVisibility(ctx, &sqlplugin.VisibilityRow{ + NamespaceID: request.NamespaceID, + WorkflowID: request.WorkflowID, + RunID: request.RunID, + StartTime: request.StartTime, + ExecutionTime: request.ExecutionTime, + WorkflowTypeName: request.WorkflowTypeName, + Status: int32(request.Status), + Memo: request.Memo.Data, + Encoding: request.Memo.EncodingType.String(), + TaskQueue: request.TaskQueue, + SearchAttributes: searchAttributes, + }) + if err != nil { + return err + } + noRowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if noRowsAffected > 2 { // either adds a new or deletes old row and adds new row + return fmt.Errorf("UpsertWorkflowExecution unexpected numRows (%v) updates", noRowsAffected) + + } + return nil +} + +func (s *VisibilityStore) ListOpenWorkflowExecutions( + ctx context.Context, + request *manager.ListWorkflowExecutionsRequest, +) (*store.InternalListWorkflowExecutionsResponse, error) { + return s.ListWorkflowExecutions( + ctx, + &manager.ListWorkflowExecutionsRequestV2{ + NamespaceID: request.NamespaceID, + Namespace: request.Namespace, + PageSize: request.PageSize, + NextPageToken: request.NextPageToken, + Query: fmt.Sprintf( + "%s = %d AND %s BETWEEN '%s' AND '%s'", + searchattribute.ExecutionStatus, + int32(enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING), + searchattribute.StartTime, + request.EarliestStartTime.UTC().Format(time.RFC3339Nano), + request.LatestStartTime.UTC().Format(time.RFC3339Nano), + ), + }, + ) +} + +func (s *VisibilityStore) ListClosedWorkflowExecutions( + ctx context.Context, + request *manager.ListWorkflowExecutionsRequest, +) (*store.InternalListWorkflowExecutionsResponse, error) { + return s.ListWorkflowExecutions( + ctx, + &manager.ListWorkflowExecutionsRequestV2{ + NamespaceID: request.NamespaceID, + Namespace: request.Namespace, + PageSize: request.PageSize, + NextPageToken: request.NextPageToken, + Query: fmt.Sprintf( + "%s != %d AND %s BETWEEN '%s' AND '%s'", + searchattribute.ExecutionStatus, + int32(enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING), + searchattribute.CloseTime, + request.EarliestStartTime.UTC().Format(time.RFC3339Nano), + request.LatestStartTime.UTC().Format(time.RFC3339Nano), + ), + }, + ) +} + +func (s *VisibilityStore) ListOpenWorkflowExecutionsByType( + ctx context.Context, + request *manager.ListWorkflowExecutionsByTypeRequest, +) (*store.InternalListWorkflowExecutionsResponse, error) { + return s.ListWorkflowExecutions( + ctx, + &manager.ListWorkflowExecutionsRequestV2{ + NamespaceID: request.NamespaceID, + Namespace: request.Namespace, + PageSize: request.PageSize, + NextPageToken: request.NextPageToken, + Query: fmt.Sprintf( + "%s = %d AND %s = '%s' AND %s BETWEEN '%s' AND '%s'", + searchattribute.ExecutionStatus, + int32(enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING), + searchattribute.WorkflowType, + request.WorkflowTypeName, + searchattribute.StartTime, + request.EarliestStartTime.UTC().Format(time.RFC3339Nano), + request.LatestStartTime.UTC().Format(time.RFC3339Nano), + ), + }, + ) +} + +func (s *VisibilityStore) ListClosedWorkflowExecutionsByType( + ctx context.Context, + request *manager.ListWorkflowExecutionsByTypeRequest, +) (*store.InternalListWorkflowExecutionsResponse, error) { + return s.ListWorkflowExecutions( + ctx, + &manager.ListWorkflowExecutionsRequestV2{ + NamespaceID: request.NamespaceID, + Namespace: request.Namespace, + PageSize: request.PageSize, + NextPageToken: request.NextPageToken, + Query: fmt.Sprintf( + "%s != %d AND %s = '%s' AND %s BETWEEN '%s' AND '%s'", + searchattribute.ExecutionStatus, + int32(enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING), + searchattribute.WorkflowType, + request.WorkflowTypeName, + searchattribute.CloseTime, + request.EarliestStartTime.UTC().Format(time.RFC3339Nano), + request.LatestStartTime.UTC().Format(time.RFC3339Nano), + ), + }, + ) +} + +func (s *VisibilityStore) ListOpenWorkflowExecutionsByWorkflowID( + ctx context.Context, + request *manager.ListWorkflowExecutionsByWorkflowIDRequest, +) (*store.InternalListWorkflowExecutionsResponse, error) { + return s.ListWorkflowExecutions( + ctx, + &manager.ListWorkflowExecutionsRequestV2{ + NamespaceID: request.NamespaceID, + Namespace: request.Namespace, + PageSize: request.PageSize, + NextPageToken: request.NextPageToken, + Query: fmt.Sprintf( + "%s = %d AND %s = '%s' AND %s BETWEEN '%s' AND '%s'", + searchattribute.ExecutionStatus, + int32(enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING), + searchattribute.WorkflowID, + request.WorkflowID, + searchattribute.StartTime, + request.EarliestStartTime.UTC().Format(time.RFC3339Nano), + request.LatestStartTime.UTC().Format(time.RFC3339Nano), + ), + }, + ) +} + +func (s *VisibilityStore) ListClosedWorkflowExecutionsByWorkflowID( + ctx context.Context, + request *manager.ListWorkflowExecutionsByWorkflowIDRequest, +) (*store.InternalListWorkflowExecutionsResponse, error) { + return s.ListWorkflowExecutions( + ctx, + &manager.ListWorkflowExecutionsRequestV2{ + NamespaceID: request.NamespaceID, + Namespace: request.Namespace, + PageSize: request.PageSize, + NextPageToken: request.NextPageToken, + Query: fmt.Sprintf( + "%s != %d AND %s = '%s' AND %s BETWEEN '%s' AND '%s'", + searchattribute.ExecutionStatus, + int32(enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING), + searchattribute.WorkflowID, + request.WorkflowID, + searchattribute.CloseTime, + request.EarliestStartTime.UTC().Format(time.RFC3339Nano), + request.LatestStartTime.UTC().Format(time.RFC3339Nano), + ), + }, + ) +} + +func (s *VisibilityStore) ListClosedWorkflowExecutionsByStatus( + ctx context.Context, + request *manager.ListClosedWorkflowExecutionsByStatusRequest, +) (*store.InternalListWorkflowExecutionsResponse, error) { + return s.ListWorkflowExecutions( + ctx, + &manager.ListWorkflowExecutionsRequestV2{ + NamespaceID: request.NamespaceID, + Namespace: request.Namespace, + PageSize: request.PageSize, + NextPageToken: request.NextPageToken, + Query: fmt.Sprintf( + "%s = %d AND %s BETWEEN '%s' AND '%s'", + searchattribute.ExecutionStatus, + int32(request.Status), + searchattribute.CloseTime, + request.EarliestStartTime.UTC().Format(time.RFC3339Nano), + request.LatestStartTime.UTC().Format(time.RFC3339Nano), + ), + }, + ) +} + +func (s *VisibilityStore) DeleteWorkflowExecution( + ctx context.Context, + request *manager.VisibilityDeleteWorkflowExecutionRequest, +) error { + _, err := s.sqlStore.Db.DeleteFromVisibility(ctx, sqlplugin.VisibilityDeleteFilter{ + NamespaceID: request.NamespaceID.String(), + RunID: request.RunID, + }) + if err != nil { + return serviceerror.NewUnavailable(err.Error()) + } + return nil +} + +func (s *VisibilityStore) ListWorkflowExecutions( + ctx context.Context, + request *manager.ListWorkflowExecutionsRequestV2, +) (*store.InternalListWorkflowExecutionsResponse, error) { + saTypeMap, err := s.searchAttributesProvider.GetSearchAttributes(s.GetIndexName(), false) + if err != nil { + return nil, err + } + + saMapper, err := s.searchAttributesMapperProvider.GetMapper(request.Namespace) + if err != nil { + return nil, err + } + + converter := NewQueryConverter(s.GetName(), request, saTypeMap, saMapper) + selectFilter, err := converter.BuildSelectStmt() + if err != nil { + return nil, err + } + + rows, err := s.sqlStore.Db.SelectFromVisibility(ctx, *selectFilter) + if err != nil { + return nil, serviceerror.NewUnavailable( + fmt.Sprintf("ListWorkflowExecutions operation failed. Select failed: %v", err)) + } + if len(rows) == 0 { + return &store.InternalListWorkflowExecutionsResponse{}, nil + } + + var infos = make([]*store.InternalWorkflowExecutionInfo, len(rows)) + for i, row := range rows { + infos[i], err = s.rowToInfo(&row, request.Namespace) + if err != nil { + return nil, err + } + } + + var nextPageToken []byte + if len(rows) == request.PageSize { + lastRow := rows[len(rows)-1] + closeTime := maxTime + if lastRow.CloseTime != nil { + closeTime = *lastRow.CloseTime + } + nextPageToken, err = serializePageToken(&pageToken{ + CloseTime: closeTime, + StartTime: lastRow.StartTime, + RunID: lastRow.RunID, + }) + if err != nil { + return nil, err + } + } + return &store.InternalListWorkflowExecutionsResponse{ + Executions: infos, + NextPageToken: nextPageToken, + }, nil +} + +func (s *VisibilityStore) ScanWorkflowExecutions( + _ context.Context, + _ *manager.ListWorkflowExecutionsRequestV2, +) (*store.InternalListWorkflowExecutionsResponse, error) { + return nil, store.OperationNotSupportedErr +} + +func (s *VisibilityStore) CountWorkflowExecutions( + _ context.Context, + _ *manager.CountWorkflowExecutionsRequest, +) (*manager.CountWorkflowExecutionsResponse, error) { + return nil, store.OperationNotSupportedErr +} + +func (s *VisibilityStore) GetWorkflowExecution( + ctx context.Context, + request *manager.GetWorkflowExecutionRequest, +) (*store.InternalGetWorkflowExecutionResponse, error) { + row, err := s.sqlStore.Db.GetFromVisibility(ctx, sqlplugin.VisibilityGetFilter{ + NamespaceID: request.NamespaceID.String(), + RunID: request.RunID, + }) + if err != nil { + return nil, serviceerror.NewUnavailable( + fmt.Sprintf("GetWorkflowExecution operation failed. Select failed: %v", err)) + } + info, err := s.rowToInfo(row, request.Namespace) + if err != nil { + return nil, err + } + return &store.InternalGetWorkflowExecutionResponse{ + Execution: info, + }, nil +} + +func (s *VisibilityStore) prepareSearchAttributesForDb( + request *store.InternalVisibilityRequestBase, +) (*sqlplugin.VisibilitySearchAttributes, error) { + if request.SearchAttributes == nil { + return nil, nil + } + + saTypeMap, err := s.searchAttributesProvider.GetSearchAttributes( + s.GetIndexName(), + false, + ) + if err != nil { + return nil, serviceerror.NewUnavailable( + fmt.Sprintf("Unable to read search attributes types: %v", err)) + } + + var searchAttributes sqlplugin.VisibilitySearchAttributes + searchAttributes, err = searchattribute.Decode(request.SearchAttributes, &saTypeMap) + if err != nil { + return nil, err + } + + for name, value := range searchAttributes { + tp, err := saTypeMap.GetType(name) + if err != nil { + return nil, err + } + if tp == enumspb.INDEXED_VALUE_TYPE_DATETIME { + if dt, ok := value.(time.Time); ok { + searchAttributes[name] = dt.Format(time.RFC3339Nano) + } + } + } + return &searchAttributes, nil +} + +func (s *VisibilityStore) rowToInfo( + row *sqlplugin.VisibilityRow, + nsName namespace.Name, +) (*store.InternalWorkflowExecutionInfo, error) { + if row.ExecutionTime.UnixNano() == 0 { + row.ExecutionTime = row.StartTime + } + info := &store.InternalWorkflowExecutionInfo{ + WorkflowID: row.WorkflowID, + RunID: row.RunID, + TypeName: row.WorkflowTypeName, + StartTime: row.StartTime, + ExecutionTime: row.ExecutionTime, + Memo: persistence.NewDataBlob(row.Memo, row.Encoding), + Status: enumspb.WorkflowExecutionStatus(row.Status), + TaskQueue: row.TaskQueue, + } + if row.SearchAttributes != nil && len(*row.SearchAttributes) > 0 { + searchAttributes, err := s.processRowSearchAttributes(*row.SearchAttributes, nsName) + if err != nil { + return nil, err + } + info.SearchAttributes = searchAttributes + } + if row.CloseTime != nil { + info.CloseTime = *row.CloseTime + } + if row.HistoryLength != nil { + info.HistoryLength = *row.HistoryLength + } + return info, nil +} + +func (s *VisibilityStore) processRowSearchAttributes( + rowSearchAttributes sqlplugin.VisibilitySearchAttributes, + nsName namespace.Name, +) (*common.SearchAttributes, error) { + saTypeMap, err := s.searchAttributesProvider.GetSearchAttributes( + s.GetIndexName(), + false, + ) + if err != nil { + return nil, serviceerror.NewUnavailable( + fmt.Sprintf("Unable to read search attributes types: %v", err)) + } + // In SQLite, keyword list can return a string when there's only one element. + // This changes it into a slice. + for name, value := range rowSearchAttributes { + tp, err := saTypeMap.GetType(name) + if err != nil { + return nil, err + } + if tp == enumspb.INDEXED_VALUE_TYPE_KEYWORD_LIST { + switch v := value.(type) { + case []string: + // no-op + case string: + (rowSearchAttributes)[name] = []string{v} + default: + return nil, serviceerror.NewInternal( + fmt.Sprintf("Unexpected data type for keyword list: %T (expected list of strings)", v), + ) + } + } + } + searchAttributes, err := searchattribute.Encode(rowSearchAttributes, &saTypeMap) + if err != nil { + return nil, err + } + aliasedSas, err := searchattribute.AliasFields( + s.searchAttributesMapperProvider, + searchAttributes, + nsName.String(), + ) + if err != nil { + return nil, err + } + if aliasedSas != nil { + searchAttributes = aliasedSas + } + return searchAttributes, nil +} diff --git a/common/searchattribute/defs.go b/common/searchattribute/defs.go index 486322d17ff..04d6b9fa4ef 100644 --- a/common/searchattribute/defs.go +++ b/common/searchattribute/defs.go @@ -104,6 +104,21 @@ var ( VisibilityTaskKey: {}, } + sqlDbSystemNameToColName = map[string]string{ + NamespaceID: "namespace_id", + WorkflowID: "workflow_id", + RunID: "run_id", + WorkflowType: "workflow_type_name", + StartTime: "start_time", + ExecutionTime: "execution_time", + CloseTime: "close_time", + ExecutionStatus: "status", + TaskQueue: "task_queue", + HistoryLength: "history_length", + Memo: "memo", + MemoEncoding: "encoding", + } + sqlDbCustomSearchAttributes = map[string]enumspb.IndexedValueType{ "Bool01": enumspb.INDEXED_VALUE_TYPE_BOOL, "Bool02": enumspb.INDEXED_VALUE_TYPE_BOOL, @@ -161,6 +176,15 @@ func IsMappable(name string) bool { return true } +// GetSqlDbColName maps system and reserved search attributes to column names for SQL tables. +// If the input is not a system or reserved search attribute, then it returns the input. +func GetSqlDbColName(name string) string { + if fieldName, ok := sqlDbSystemNameToColName[name]; ok { + return fieldName + } + return name +} + func GetSqlDbIndexSearchAttributes() *persistencespb.IndexSearchAttributes { return &persistencespb.IndexSearchAttributes{ CustomSearchAttributes: sqlDbCustomSearchAttributes,