Skip to content

Commit

Permalink
Strict validation when using SQL DB for visibility (#3905)
Browse files Browse the repository at this point in the history
  • Loading branch information
rodrigozhou authored Feb 7, 2023
1 parent 52c3a9e commit a840d75
Show file tree
Hide file tree
Showing 17 changed files with 376 additions and 193 deletions.
17 changes: 17 additions & 0 deletions common/persistence/visibility/defs.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@

package visibility

import (
"go.temporal.io/server/common/persistence/sql/sqlplugin/mysql"
"go.temporal.io/server/common/persistence/sql/sqlplugin/postgresql"
"go.temporal.io/server/common/persistence/sql/sqlplugin/sqlite"
)

const (
// AdvancedVisibilityWritingModeOff means do not write to advanced visibility store
AdvancedVisibilityWritingModeOff = "off"
Expand All @@ -40,3 +46,14 @@ func DefaultAdvancedVisibilityWritingMode(advancedVisibilityConfigExist bool) st
}
return AdvancedVisibilityWritingModeOff
}

func AllowListForValidation(pluginName string) bool {
switch pluginName {
case mysql.PluginNameV8, postgresql.PluginNameV12, sqlite.PluginName:
// Advanced visibility with SQL DB don't support list of values
return false
default:
// Otherwise, enable for backward compatibility.
return true
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ func (s *visibilityStore) generateESDoc(request *store.InternalVisibilityRequest
return nil, serviceerror.NewUnavailable(fmt.Sprintf("Unable to read search attribute types: %v", err))
}

searchAttributes, err := searchattribute.Decode(request.SearchAttributes, &typeMap)
searchAttributes, err := searchattribute.Decode(request.SearchAttributes, &typeMap, true)
if err != nil {
s.metricsHandler.Counter(metrics.ElasticsearchDocumentGenerateFailuresCount.GetMetricName()).Record(1)
return nil, serviceerror.NewInternal(fmt.Sprintf("Unable to decode search attributes: %v", err))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,7 @@ func (s *ESVisibilitySuite) TestParseESDoc_SearchAttributes() {
info, err := s.visibilityStore.parseESDoc("", docSource, searchattribute.TestNameTypeMap, testNamespace)
s.NoError(err)
s.NotNil(info)
customSearchAttributes, err := searchattribute.Decode(info.SearchAttributes, &searchattribute.TestNameTypeMap)
customSearchAttributes, err := searchattribute.Decode(info.SearchAttributes, &searchattribute.TestNameTypeMap, true)
s.NoError(err)

s.Len(customSearchAttributes, 7)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ func (s *VisibilityStore) prepareSearchAttributesForDb(
}

var searchAttributes sqlplugin.VisibilitySearchAttributes
searchAttributes, err = searchattribute.Decode(request.SearchAttributes, &saTypeMap)
searchAttributes, err = searchattribute.Decode(request.SearchAttributes, &saTypeMap, false)
if err != nil {
return nil, err
}
Expand Down
8 changes: 6 additions & 2 deletions common/searchattribute/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ func Encode(searchAttributes map[string]interface{}, typeMap *NameTypeMap) (*com
// 1. type from typeMap,
// 2. if typeMap is nil, type from MetadataType field is used.
// In case of error, it will continue to next search attribute and return last error.
func Decode(searchAttributes *commonpb.SearchAttributes, typeMap *NameTypeMap) (map[string]interface{}, error) {
func Decode(
searchAttributes *commonpb.SearchAttributes,
typeMap *NameTypeMap,
allowList bool,
) (map[string]interface{}, error) {
if len(searchAttributes.GetIndexedFields()) == 0 {
return nil, nil
}
Expand All @@ -84,7 +88,7 @@ func Decode(searchAttributes *commonpb.SearchAttributes, typeMap *NameTypeMap) (
}
}

searchAttributeValue, err := DecodeValue(saPayload, saType)
searchAttributeValue, err := DecodeValue(saPayload, saType, allowList)
if err != nil {
lastErr = err
result[saName] = nil
Expand Down
22 changes: 13 additions & 9 deletions common/searchattribute/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func Test_Decode_Success(t *testing.T) {
}, typeMap)
assert.NoError(err)

vals, err := Decode(sa, typeMap)
vals, err := Decode(sa, typeMap, true)
assert.NoError(err)
assert.Len(vals, 6)
assert.Equal("val1", vals["key1"])
Expand All @@ -154,7 +154,7 @@ func Test_Decode_Success(t *testing.T) {
delete(sa.IndexedFields["key5"].Metadata, "type")
delete(sa.IndexedFields["key6"].Metadata, "type")

vals, err = Decode(sa, typeMap)
vals, err = Decode(sa, typeMap, true)
assert.NoError(err)
assert.Len(vals, 6)
assert.Equal("val1", vals["key1"])
Expand Down Expand Up @@ -185,7 +185,7 @@ func Test_Decode_NilMap(t *testing.T) {
}, typeMap)
assert.NoError(err)

vals, err := Decode(sa, nil)
vals, err := Decode(sa, nil, true)
assert.NoError(err)
assert.Len(sa.IndexedFields, 6)
assert.Equal("val1", vals["key1"])
Expand All @@ -211,11 +211,15 @@ func Test_Decode_Error(t *testing.T) {
}, typeMap)
assert.NoError(err)

vals, err := Decode(sa, &NameTypeMap{customSearchAttributes: map[string]enumspb.IndexedValueType{
"key1": enumspb.INDEXED_VALUE_TYPE_TEXT,
"key4": enumspb.INDEXED_VALUE_TYPE_INT,
"key3": enumspb.INDEXED_VALUE_TYPE_BOOL,
}})
vals, err := Decode(
sa,
&NameTypeMap{customSearchAttributes: map[string]enumspb.IndexedValueType{
"key1": enumspb.INDEXED_VALUE_TYPE_TEXT,
"key4": enumspb.INDEXED_VALUE_TYPE_INT,
"key3": enumspb.INDEXED_VALUE_TYPE_BOOL,
}},
true,
)
assert.Error(err)
assert.True(errors.Is(err, ErrInvalidName))
assert.Len(sa.IndexedFields, 3)
Expand All @@ -227,7 +231,7 @@ func Test_Decode_Error(t *testing.T) {
delete(sa.IndexedFields["key2"].Metadata, "type")
delete(sa.IndexedFields["key3"].Metadata, "type")

vals, err = Decode(sa, nil)
vals, err = Decode(sa, nil, true)
assert.Error(err)
assert.True(errors.Is(err, ErrInvalidType))
assert.Len(vals, 3)
Expand Down
126 changes: 52 additions & 74 deletions common/searchattribute/encode_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,44 @@ func EncodeValue(val interface{}, t enumspb.IndexedValueType) (*commonpb.Payload
// DecodeValue decodes search attribute value from Payload using (in order):
// 1. passed type t.
// 2. type from MetadataType field, if t is not specified.
func DecodeValue(value *commonpb.Payload, t enumspb.IndexedValueType) (interface{}, error) {
// allowList allows list of values when it's not keyword list type.
func DecodeValue(
value *commonpb.Payload,
t enumspb.IndexedValueType,
allowList bool,
) (any, error) {
if t == enumspb.INDEXED_VALUE_TYPE_UNSPECIFIED {
t = enumspb.IndexedValueType(enumspb.IndexedValueType_value[string(value.Metadata[MetadataType])])
t = enumspb.IndexedValueType(
enumspb.IndexedValueType_value[string(value.Metadata[MetadataType])],
)
}

// Here are similar code sections for all types.
switch t {
case enumspb.INDEXED_VALUE_TYPE_BOOL:
return decodeValueTyped[bool](value, allowList)
case enumspb.INDEXED_VALUE_TYPE_DATETIME:
return decodeValueTyped[time.Time](value, allowList)
case enumspb.INDEXED_VALUE_TYPE_DOUBLE:
return decodeValueTyped[float64](value, allowList)
case enumspb.INDEXED_VALUE_TYPE_INT:
return decodeValueTyped[int64](value, allowList)
case enumspb.INDEXED_VALUE_TYPE_KEYWORD:
return decodeValueTyped[string](value, allowList)
case enumspb.INDEXED_VALUE_TYPE_TEXT:
return decodeValueTyped[string](value, allowList)
case enumspb.INDEXED_VALUE_TYPE_KEYWORD_LIST:
return decodeValueTyped[[]string](value, false)
default:
return nil, fmt.Errorf("%w: %v", ErrInvalidType, t)
}
}

// decodeValueTyped tries to decode to the given type.
// If the input is a list and allowList is false, then it will return only the first element.
// If the input is a list and allowList is true, then it will return the decoded list.
//
//nolint:revive // allowList is a control flag
func decodeValueTyped[T any](value *commonpb.Payload, allowList bool) (any, error) {
// At first, it tries to decode to pointer of actual type (i.e. `*string` for `string`).
// This is to ensure that `nil` values are decoded back as `nil` using `NilPayloadConverter`.
// If value is not `nil` but some value of expected type, the code relies on the fact that
Expand All @@ -62,82 +94,28 @@ func DecodeValue(value *commonpb.Payload, t enumspb.IndexedValueType) (interface
// If decoding to pointer type fails, it tries to decode to array of the same type because
// search attributes support polymorphism: field of specific type may also have an array of that type.
// If resulting slice has zero length, it gets substitute with `nil` to treat nils and empty slices equally.
// If allowList is true, it returns the list as it is. If allowList is false and the list has
// only one element, then return it. Otherwise, return an error.
// If search attribute value is `nil`, it means that search attribute needs to be removed from the document.

switch t {
case enumspb.INDEXED_VALUE_TYPE_TEXT,
enumspb.INDEXED_VALUE_TYPE_KEYWORD,
enumspb.INDEXED_VALUE_TYPE_KEYWORD_LIST:
var val *string
if err := payload.Decode(value, &val); err != nil {
var listVal []string
err = payload.Decode(value, &listVal)
if len(listVal) == 0 {
return nil, err
}
return listVal, err
}
if val == nil {
return nil, nil
var val *T
if err := payload.Decode(value, &val); err != nil {
var listVal []T
if err := payload.Decode(value, &listVal); err != nil {
return nil, err
}
return *val, nil
case enumspb.INDEXED_VALUE_TYPE_INT:
var val *int64
if err := payload.Decode(value, &val); err != nil {
var listVal []int64
err = payload.Decode(value, &listVal)
if len(listVal) == 0 {
return nil, err
}
return listVal, err
}
if val == nil {
return nil, nil
}
return *val, nil
case enumspb.INDEXED_VALUE_TYPE_DOUBLE:
var val *float64
if err := payload.Decode(value, &val); err != nil {
var listVal []float64
err = payload.Decode(value, &listVal)
if len(listVal) == 0 {
return nil, err
}
return listVal, err
}
if val == nil {
if len(listVal) == 0 {
return nil, nil
}
return *val, nil
case enumspb.INDEXED_VALUE_TYPE_BOOL:
var val *bool
if err := payload.Decode(value, &val); err != nil {
var listVal []bool
err = payload.Decode(value, &listVal)
if len(listVal) == 0 {
return nil, err
}
return listVal, err
if allowList {
return listVal, nil
}
if val == nil {
return nil, nil
if len(listVal) == 1 {
return listVal[0], nil
}
return *val, nil
case enumspb.INDEXED_VALUE_TYPE_DATETIME:
var val *time.Time
if err := payload.Decode(value, &val); err != nil {
var listVal []time.Time
err = payload.Decode(value, &listVal)
if len(listVal) == 0 {
return nil, err
}
return listVal, err
}
if val == nil {
return nil, nil
}
return *val, nil
default:
return nil, fmt.Errorf("%w: %v", ErrInvalidType, t)
return nil, fmt.Errorf("list of values not allowed for type %T", listVal[0])
}
if val == nil {
return nil, nil
}
return *val, nil
}
Loading

0 comments on commit a840d75

Please sign in to comment.