diff --git a/types/field_type.go b/types/field_type.go index d33d83a2d2c2d..404666744fafa 100644 --- a/types/field_type.go +++ b/types/field_type.go @@ -53,6 +53,7 @@ func AggFieldType(tps []*FieldType) *FieldType { } mtp := MergeFieldType(currType.Tp, t.Tp) currType.Tp = mtp + currType.Flag = mergeTypeFlag(currType.Flag, t.Flag) } return &currType @@ -283,6 +284,13 @@ func MergeFieldType(a byte, b byte) byte { return fieldTypeMergeRules[ia][ib] } +// mergeTypeFlag merges two MySQL type flag to a new one +// currently only NotNullFlag is checked +// todo more flag need to be checked, for example: UnsignedFlag +func mergeTypeFlag(a, b uint) uint { + return a & (b&mysql.NotNullFlag | ^mysql.NotNullFlag) +} + func getFieldTypeIndex(tp byte) int { itp := int(tp) if itp < fieldTypeTearFrom { diff --git a/types/field_type_test.go b/types/field_type_test.go index d7d9185036211..3284350914258 100644 --- a/types/field_type_test.go +++ b/types/field_type_test.go @@ -303,6 +303,32 @@ func (s *testFieldTypeSuite) TestAggFieldType(c *C) { } } } +func (s *testFieldTypeSuite) TestAggFieldTypeForTypeFlag(c *C) { + types := []*FieldType{ + NewFieldType(mysql.TypeLonglong), + NewFieldType(mysql.TypeLonglong), + } + + aggTp := AggFieldType(types) + c.Assert(aggTp.Tp, Equals, mysql.TypeLonglong) + c.Assert(aggTp.Flag, Equals, uint(0)) + + types[0].Flag = mysql.NotNullFlag + aggTp = AggFieldType(types) + c.Assert(aggTp.Tp, Equals, mysql.TypeLonglong) + c.Assert(aggTp.Flag, Equals, uint(0)) + + types[0].Flag = 0 + types[1].Flag = mysql.NotNullFlag + aggTp = AggFieldType(types) + c.Assert(aggTp.Tp, Equals, mysql.TypeLonglong) + c.Assert(aggTp.Flag, Equals, uint(0)) + + types[0].Flag = mysql.NotNullFlag + aggTp = AggFieldType(types) + c.Assert(aggTp.Tp, Equals, mysql.TypeLonglong) + c.Assert(aggTp.Flag, Equals, mysql.NotNullFlag) +} func (s *testFieldTypeSuite) TestAggregateEvalType(c *C) { defer testleak.AfterTest(c)()