Skip to content

Commit

Permalink
lock TagSettings structure when modified (go-gorm#1796)
Browse files Browse the repository at this point in the history
The map is modified in different places in the code which results in race conditions
on execution.
This commit locks the map with read-write lock when it is modified
  • Loading branch information
posener authored and jinzhu committed Sep 9, 2018
1 parent 282f11a commit 123d4f5
Show file tree
Hide file tree
Showing 12 changed files with 90 additions and 61 deletions.
2 changes: 1 addition & 1 deletion callback_query_preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func autoPreload(scope *Scope) {
continue
}

if val, ok := field.TagSettings["PRELOAD"]; ok {
if val, ok := field.TagSettingsGet("PRELOAD"); ok {
if preload, err := strconv.ParseBool(val); err != nil {
scope.Err(errors.New("invalid preload option"))
return
Expand Down
8 changes: 4 additions & 4 deletions callback_save.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,27 +35,27 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea
autoUpdate = checkTruth(value)
autoCreate = autoUpdate
saveReference = autoUpdate
} else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok {
} else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok {
autoUpdate = checkTruth(value)
autoCreate = autoUpdate
saveReference = autoUpdate
}

if value, ok := scope.Get("gorm:association_autoupdate"); ok {
autoUpdate = checkTruth(value)
} else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok {
} else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOUPDATE"); ok {
autoUpdate = checkTruth(value)
}

if value, ok := scope.Get("gorm:association_autocreate"); ok {
autoCreate = checkTruth(value)
} else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok {
} else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOCREATE"); ok {
autoCreate = checkTruth(value)
}

if value, ok := scope.Get("gorm:association_save_reference"); ok {
saveReference = checkTruth(value)
} else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok {
} else if value, ok := field.TagSettingsGet("ASSOCIATION_SAVE_REFERENCE"); ok {
saveReference = checkTruth(value)
}
}
Expand Down
10 changes: 6 additions & 4 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel
// Get redirected field type
var (
reflectType = field.Struct.Type
dataType = field.TagSettings["TYPE"]
dataType, _ = field.TagSettingsGet("TYPE")
)

for reflectType.Kind() == reflect.Ptr {
Expand Down Expand Up @@ -112,15 +112,17 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel
}

// Default Size
if num, ok := field.TagSettings["SIZE"]; ok {
if num, ok := field.TagSettingsGet("SIZE"); ok {
size, _ = strconv.Atoi(num)
} else {
size = 255
}

// Default type from tag setting
additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
if value, ok := field.TagSettings["DEFAULT"]; ok {
notNull, _ := field.TagSettingsGet("NOT NULL")
unique, _ := field.TagSettingsGet("UNIQUE")
additionalType = notNull + " " + unique
if value, ok := field.TagSettingsGet("DEFAULT"); ok {
additionalType = additionalType + " DEFAULT " + value
}

Expand Down
2 changes: 1 addition & 1 deletion dialect_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (commonDialect) Quote(key string) string {
}

func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool {
if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
return strings.ToLower(value) != "false"
}
return field.IsPrimaryKey
Expand Down
22 changes: 11 additions & 11 deletions dialect_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ func (s *mysql) DataTypeOf(field *StructField) string {

// MySQL allows only one auto increment column per table, and it must
// be a KEY column.
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey {
delete(field.TagSettings, "AUTO_INCREMENT")
if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey {
field.TagSettingsDelete("AUTO_INCREMENT")
}
}

Expand All @@ -45,42 +45,42 @@ func (s *mysql) DataTypeOf(field *StructField) string {
sqlType = "boolean"
case reflect.Int8:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "tinyint AUTO_INCREMENT"
} else {
sqlType = "tinyint"
}
case reflect.Int, reflect.Int16, reflect.Int32:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "int AUTO_INCREMENT"
} else {
sqlType = "int"
}
case reflect.Uint8:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "tinyint unsigned AUTO_INCREMENT"
} else {
sqlType = "tinyint unsigned"
}
case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "int unsigned AUTO_INCREMENT"
} else {
sqlType = "int unsigned"
}
case reflect.Int64:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "bigint AUTO_INCREMENT"
} else {
sqlType = "bigint"
}
case reflect.Uint64:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "bigint unsigned AUTO_INCREMENT"
} else {
sqlType = "bigint unsigned"
Expand All @@ -96,11 +96,11 @@ func (s *mysql) DataTypeOf(field *StructField) string {
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
precision := ""
if p, ok := field.TagSettings["PRECISION"]; ok {
if p, ok := field.TagSettingsGet("PRECISION"); ok {
precision = fmt.Sprintf("(%s)", p)
}

if _, ok := field.TagSettings["NOT NULL"]; ok {
if _, ok := field.TagSettingsGet("NOT NULL"); ok {
sqlType = fmt.Sprintf("timestamp%v", precision)
} else {
sqlType = fmt.Sprintf("timestamp%v NULL", precision)
Expand Down
6 changes: 3 additions & 3 deletions dialect_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,22 @@ func (s *postgres) DataTypeOf(field *StructField) string {
sqlType = "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "serial"
} else {
sqlType = "integer"
}
case reflect.Int64, reflect.Uint32, reflect.Uint64:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "bigserial"
} else {
sqlType = "bigint"
}
case reflect.Float32, reflect.Float64:
sqlType = "numeric"
case reflect.String:
if _, ok := field.TagSettings["SIZE"]; !ok {
if _, ok := field.TagSettingsGet("SIZE"); !ok {
size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different
}

Expand Down
4 changes: 2 additions & 2 deletions dialect_sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ func (s *sqlite3) DataTypeOf(field *StructField) string {
sqlType = "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "integer primary key autoincrement"
} else {
sqlType = "integer"
}
case reflect.Int64, reflect.Uint64:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "integer primary key autoincrement"
} else {
sqlType = "bigint"
Expand Down
8 changes: 4 additions & 4 deletions dialects/mssql/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
func setIdentityInsert(scope *gorm.Scope) {
if scope.Dialect().GetName() == "mssql" {
for _, field := range scope.PrimaryFields() {
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank {
if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsBlank {
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
scope.InstanceSet("mssql:identity_insert_on", true)
}
Expand Down Expand Up @@ -70,14 +70,14 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string {
sqlType = "bit"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "int IDENTITY(1,1)"
} else {
sqlType = "int"
}
case reflect.Int64, reflect.Uint64:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "bigint IDENTITY(1,1)"
} else {
sqlType = "bigint"
Expand Down Expand Up @@ -116,7 +116,7 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string {
}

func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
return value != "FALSE"
}
return field.IsPrimaryKey
Expand Down
2 changes: 1 addition & 1 deletion field_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func TestCalculateField(t *testing.T) {

if field, ok := scope.FieldByName("embedded_name"); !ok {
t.Errorf("should find embedded field")
} else if _, ok := field.TagSettings["NOT NULL"]; !ok {
} else if _, ok := field.TagSettingsGet("NOT NULL"); !ok {
t.Errorf("should find embedded field's tag settings")
}
}
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
scope := s.NewScope(source)
for _, field := range scope.GetModelStruct().StructFields {
if field.Name == column || field.DBName == column {
if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" {
source := (&Scope{Value: source}).GetModelStruct().ModelType
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
handler.Setup(field.Relationship, many2many, source, destination)
Expand Down
Loading

0 comments on commit 123d4f5

Please sign in to comment.