diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index f12bd2832abf8..94eaef8a84574 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -16,6 +16,7 @@ package expression import ( "bytes" "compress/zlib" + "crypto/aes" "crypto/md5" "crypto/rand" "crypto/sha1" @@ -25,9 +26,11 @@ import ( "fmt" "hash" "io" + "strings" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/auth" "github.com/pingcap/tidb/util/chunk" @@ -57,7 +60,9 @@ var ( var ( _ builtinFunc = &builtinAesDecryptSig{} + _ builtinFunc = &builtinAesDecryptIVSig{} _ builtinFunc = &builtinAesEncryptSig{} + _ builtinFunc = &builtinAesEncryptIVSig{} _ builtinFunc = &builtinCompressSig{} _ builtinFunc = &builtinMD5Sig{} _ builtinFunc = &builtinPasswordSig{} @@ -68,10 +73,27 @@ var ( _ builtinFunc = &builtinUncompressedLengthSig{} ) -// TODO: support other mode -const ( - aes128ecbBlobkSize = 16 -) +// ivSize indicates the initialization vector supplied to aes_decrypt +const ivSize = aes.BlockSize + +// aesModeAttr indicates that the key length and iv attribute for specific block_encryption_mode. +// keySize is the key length in bits and mode is the encryption mode. +// ivRequired indicates that initialization vector is required or not. +type aesModeAttr struct { + modeName string + keySize int + ivRequired bool +} + +var aesModes = map[string]*aesModeAttr{ + //TODO support more modes, permitted mode values are: ECB, CBC, CFB1, CFB8, CFB128, OFB + "aes-128-ecb": {"ecb", 16, false}, + "aes-192-ecb": {"ecb", 24, false}, + "aes-256-ecb": {"ecb", 32, false}, + "aes-128-cbc": {"cbc", 16, true}, + "aes-192-cbc": {"cbc", 24, true}, + "aes-256-cbc": {"cbc", 32, true}, +} type aesDecryptFunctionClass struct { baseFunctionClass @@ -81,20 +103,37 @@ func (c *aesDecryptFunctionClass) getFunction(ctx sessionctx.Context, args []Exp if err := c.verifyArgs(args); err != nil { return nil, errors.Trace(c.verifyArgs(args)) } - bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString, types.ETString) + argTps := make([]types.EvalType, 0, len(args)) + for range args { + argTps = append(argTps, types.ETString) + } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, argTps...) bf.tp.Flen = args[0].GetType().Flen // At most. types.SetBinChsClnFlag(bf.tp) - sig := &builtinAesDecryptSig{bf} - return sig, nil + + blockMode, _ := ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) + mode, exists := aesModes[strings.ToLower(blockMode)] + if !exists { + return nil, errors.Errorf("unsupported block encryption mode - %v", blockMode) + } + if mode.ivRequired { + if len(args) != 3 { + return nil, ErrIncorrectParameterCount.GenWithStackByArgs("aes_decrypt") + } + return &builtinAesDecryptIVSig{bf, mode}, nil + } + return &builtinAesDecryptSig{bf, mode}, nil } type builtinAesDecryptSig struct { baseBuiltinFunc + *aesModeAttr } func (b *builtinAesDecryptSig) Clone() builtinFunc { newSig := &builtinAesDecryptSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.aesModeAttr = b.aesModeAttr return newSig } @@ -106,15 +145,73 @@ func (b *builtinAesDecryptSig) evalString(row chunk.Row) (string, bool, error) { if isNull || err != nil { return "", true, errors.Trace(err) } + keyStr, isNull, err := b.args[1].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + if !b.ivRequired && len(b.args) == 3 { + // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenWithStackByArgs("IV")) + } + + key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize) + var plainText []byte + switch b.modeName { + case "ecb": + plainText, err = encrypt.AESDecryptWithECB([]byte(cryptStr), key) + default: + return "", true, errors.Errorf("unsupported block encryption mode - %v", b.modeName) + } + if err != nil { + return "", true, nil + } + return string(plainText), false, nil +} + +type builtinAesDecryptIVSig struct { + baseBuiltinFunc + *aesModeAttr +} + +func (b *builtinAesDecryptIVSig) Clone() builtinFunc { + newSig := &builtinAesDecryptIVSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.aesModeAttr = b.aesModeAttr + return newSig +} + +// evalString evals AES_DECRYPT(crypt_str, key_key, iv). +// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_aes-decrypt +func (b *builtinAesDecryptIVSig) evalString(row chunk.Row) (string, bool, error) { + // According to doc: If either function argument is NULL, the function returns NULL. + cryptStr, isNull, err := b.args[0].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } keyStr, isNull, err := b.args[1].EvalString(b.ctx, row) if isNull || err != nil { return "", true, errors.Trace(err) } - // TODO: Support other modes. - key := encrypt.DeriveKeyMySQL([]byte(keyStr), aes128ecbBlobkSize) - plainText, err := encrypt.AESDecryptWithECB([]byte(cryptStr), key) + iv, isNull, err := b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + if len(iv) < aes.BlockSize { + return "", true, errIncorrectArgs.GenWithStack("The initialization vector supplied to aes_decrypt is too short. Must be at least %d bytes long", aes.BlockSize) + } + // init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored) + iv = iv[0:aes.BlockSize] + + key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize) + var plainText []byte + switch b.modeName { + case "cbc": + plainText, err = encrypt.AESDecryptWithCBC([]byte(cryptStr), key, []byte(iv)) + default: + return "", true, errors.Errorf("unsupported block encryption mode - %v", b.modeName) + } if err != nil { return "", true, nil } @@ -129,20 +226,37 @@ func (c *aesEncryptFunctionClass) getFunction(ctx sessionctx.Context, args []Exp if err := c.verifyArgs(args); err != nil { return nil, errors.Trace(c.verifyArgs(args)) } - bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString, types.ETString) - bf.tp.Flen = aes128ecbBlobkSize * (args[0].GetType().Flen/aes128ecbBlobkSize + 1) // At most. + argTps := make([]types.EvalType, 0, len(args)) + for range args { + argTps = append(argTps, types.ETString) + } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, argTps...) + bf.tp.Flen = aes.BlockSize * (args[0].GetType().Flen/aes.BlockSize + 1) // At most. types.SetBinChsClnFlag(bf.tp) - sig := &builtinAesEncryptSig{bf} - return sig, nil + + blockMode, _ := ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) + mode, exists := aesModes[strings.ToLower(blockMode)] + if !exists { + return nil, errors.Errorf("unsupported block encryption mode - %v", blockMode) + } + if mode.ivRequired { + if len(args) != 3 { + return nil, ErrIncorrectParameterCount.GenWithStackByArgs("aes_encrypt") + } + return &builtinAesEncryptIVSig{bf, mode}, nil + } + return &builtinAesEncryptSig{bf, mode}, nil } type builtinAesEncryptSig struct { baseBuiltinFunc + *aesModeAttr } func (b *builtinAesEncryptSig) Clone() builtinFunc { newSig := &builtinAesEncryptSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.aesModeAttr = b.aesModeAttr return newSig } @@ -154,15 +268,73 @@ func (b *builtinAesEncryptSig) evalString(row chunk.Row) (string, bool, error) { if isNull || err != nil { return "", true, errors.Trace(err) } + keyStr, isNull, err := b.args[1].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + if !b.ivRequired && len(b.args) == 3 { + // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenWithStackByArgs("IV")) + } + + key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize) + var cipherText []byte + switch b.modeName { + case "ecb": + cipherText, err = encrypt.AESEncryptWithECB([]byte(str), key) + default: + return "", true, errors.Errorf("unsupported block encryption mode - %v", b.modeName) + } + if err != nil { + return "", true, nil + } + return string(cipherText), false, nil +} + +type builtinAesEncryptIVSig struct { + baseBuiltinFunc + *aesModeAttr +} + +func (b *builtinAesEncryptIVSig) Clone() builtinFunc { + newSig := &builtinAesEncryptIVSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.aesModeAttr = b.aesModeAttr + return newSig +} + +// evalString evals AES_ENCRYPT(str, key_str, iv). +// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_aes-decrypt +func (b *builtinAesEncryptIVSig) evalString(row chunk.Row) (string, bool, error) { + // According to doc: If either function argument is NULL, the function returns NULL. + str, isNull, err := b.args[0].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } keyStr, isNull, err := b.args[1].EvalString(b.ctx, row) if isNull || err != nil { return "", true, errors.Trace(err) } - // TODO: Support other modes. - key := encrypt.DeriveKeyMySQL([]byte(keyStr), aes128ecbBlobkSize) - cipherText, err := encrypt.AESEncryptWithECB([]byte(str), key) + iv, isNull, err := b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + if len(iv) < aes.BlockSize { + return "", true, errIncorrectArgs.GenWithStack("The initialization vector supplied to aes_encrypt is too short. Must be at least %d bytes long", aes.BlockSize) + } + // init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored) + iv = iv[0:aes.BlockSize] + + key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize) + var cipherText []byte + switch b.modeName { + case "cbc": + cipherText, err = encrypt.AESEncryptWithCBC([]byte(str), key, []byte(iv)) + default: + return "", true, errors.Errorf("unsupported block encryption mode - %v", b.modeName) + } if err != nil { return "", true, nil } diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index 6fa7e90e22dc8..99f29ea972933 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -19,6 +19,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" @@ -77,47 +78,69 @@ func (s *testEvaluatorSuite) TestSQLEncode(c *C) { } var aesTests = []struct { + mode string origin interface{} - key interface{} + params []interface{} crypt interface{} }{ - {"pingcap", "1234567890123456", "697BFE9B3F8C2F289DD82C88C7BC95C4"}, - {"pingcap123", "1234567890123456", "CEC348F4EF5F84D3AA6C4FA184C65766"}, - {"pingcap", "123456789012345678901234", "6F1589686860C8E8C7A40A78B25FF2C0"}, - {"pingcap", "123", "996E0CA8688D7AD20819B90B273E01C6"}, - {"pingcap", 123, "996E0CA8688D7AD20819B90B273E01C6"}, - {nil, 123, nil}, + // test for ecb + {"aes-128-ecb", "pingcap", []interface{}{"1234567890123456"}, "697BFE9B3F8C2F289DD82C88C7BC95C4"}, + {"aes-128-ecb", "pingcap123", []interface{}{"1234567890123456"}, "CEC348F4EF5F84D3AA6C4FA184C65766"}, + {"aes-128-ecb", "pingcap", []interface{}{"123456789012345678901234"}, "6F1589686860C8E8C7A40A78B25FF2C0"}, + {"aes-128-ecb", "pingcap", []interface{}{"123"}, "996E0CA8688D7AD20819B90B273E01C6"}, + {"aes-128-ecb", "pingcap", []interface{}{123}, "996E0CA8688D7AD20819B90B273E01C6"}, + {"aes-128-ecb", nil, []interface{}{123}, nil}, + {"aes-192-ecb", "pingcap", []interface{}{"1234567890123456"}, "9B139FD002E6496EA2D5C73A2265E661"}, + {"aes-256-ecb", "pingcap", []interface{}{"1234567890123456"}, "F80DCDEDDBE5663BDB68F74AEDDB8EE3"}, + // test for cbc + {"aes-128-cbc", "pingcap", []interface{}{"1234567890123456", "1234567890123456"}, "2ECA0077C5EA5768A0485AA522774792"}, + {"aes-128-cbc", "pingcap", []interface{}{"123456789012345678901234", "1234567890123456"}, "483788634DA8817423BA0934FD2C096E"}, + {"aes-192-cbc", "pingcap", []interface{}{"1234567890123456", "1234567890123456"}, "516391DB38E908ECA93AAB22870EC787"}, + {"aes-256-cbc", "pingcap", []interface{}{"1234567890123456", "1234567890123456"}, "5D0E22C1E77523AEF5C3E10B65653C8F"}, + {"aes-256-cbc", "pingcap", []interface{}{"12345678901234561234567890123456", "1234567890123456"}, "A26BA27CA4BE9D361D545AA84A17002D"}, + {"aes-256-cbc", "pingcap", []interface{}{"1234567890123456", "12345678901234561234567890123456"}, "5D0E22C1E77523AEF5C3E10B65653C8F"}, } func (s *testEvaluatorSuite) TestAESEncrypt(c *C) { defer testleak.AfterTest(c)() fc := funcs[ast.AesEncrypt] for _, tt := range aesTests { - str := types.NewDatum(tt.origin) - key := types.NewDatum(tt.key) - f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{str, key})) + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum(tt.mode)) + args := []types.Datum{types.NewDatum(tt.origin)} + for _, param := range tt.params { + args = append(args, types.NewDatum(param)) + } + f, err := fc.getFunction(s.ctx, s.datumsToConstants(args)) crypt, err := evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) c.Assert(toHex(crypt), DeepEquals, types.NewDatum(tt.crypt)) } - s.testNullInput(c, ast.AesDecrypt) + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum("aes-128-ecb")) + s.testNullInput(c, ast.AesEncrypt) + s.testAmbiguousInput(c, ast.AesEncrypt) } func (s *testEvaluatorSuite) TestAESDecrypt(c *C) { defer testleak.AfterTest(c)() fc := funcs[ast.AesDecrypt] - for _, test := range aesTests { - cryptStr := fromHex(test.crypt) - key := types.NewDatum(test.key) - f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{cryptStr, key})) + for _, tt := range aesTests { + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum(tt.mode)) + args := []types.Datum{fromHex(tt.crypt)} + for _, param := range tt.params { + args = append(args, types.NewDatum(param)) + } + f, err := fc.getFunction(s.ctx, s.datumsToConstants(args)) str, err := evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) - c.Assert(str, DeepEquals, types.NewDatum(test.origin)) + c.Assert(str, DeepEquals, types.NewDatum(tt.origin)) } + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum("aes-128-ecb")) s.testNullInput(c, ast.AesDecrypt) + s.testAmbiguousInput(c, ast.AesDecrypt) } func (s *testEvaluatorSuite) testNullInput(c *C, fnName string) { + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum("aes-128-ecb")) fc := funcs[fnName] arg := types.NewStringDatum("str") var argNull types.Datum @@ -132,6 +155,28 @@ func (s *testEvaluatorSuite) testNullInput(c *C, fnName string) { c.Assert(crypt.IsNull(), IsTrue) } +func (s *testEvaluatorSuite) testAmbiguousInput(c *C, fnName string) { + fc := funcs[fnName] + arg := types.NewStringDatum("str") + // test for modes that require init_vector + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum("aes-128-cbc")) + _, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{arg, arg})) + c.Assert(err, NotNil) + f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{arg, arg, types.NewStringDatum("iv < 16 bytes")})) + c.Assert(err, IsNil) + _, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, NotNil) + + // test for modes that do not require init_vector + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum("aes-128-ecb")) + f, err = fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{arg, arg, arg})) + c.Assert(err, IsNil) + _, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(len(warnings), GreaterEqual, 1) +} + func toHex(d types.Datum) (h types.Datum) { if d.IsNull() { return diff --git a/expression/errors.go b/expression/errors.go index 92be6e1487054..517398bd0a9a3 100644 --- a/expression/errors.go +++ b/expression/errors.go @@ -38,6 +38,7 @@ var ( errDeprecatedSyntaxNoReplacement = terror.ClassExpression.New(mysql.ErrWarnDeprecatedSyntaxNoReplacement, mysql.MySQLErrName[mysql.ErrWarnDeprecatedSyntaxNoReplacement]) errBadField = terror.ClassExpression.New(mysql.ErrBadField, mysql.MySQLErrName[mysql.ErrBadField]) errWarnAllowedPacketOverflowed = terror.ClassExpression.New(mysql.ErrWarnAllowedPacketOverflowed, mysql.MySQLErrName[mysql.ErrWarnAllowedPacketOverflowed]) + errWarnOptionIgnored = terror.ClassExpression.New(mysql.WarnOptionIgnored, mysql.MySQLErrName[mysql.WarnOptionIgnored]) errTruncatedWrongValue = terror.ClassExpression.New(mysql.ErrTruncatedWrongValue, mysql.MySQLErrName[mysql.ErrTruncatedWrongValue]) ) @@ -54,6 +55,7 @@ func init() { mysql.ErrOperandColumns: mysql.ErrOperandColumns, mysql.ErrRegexp: mysql.ErrRegexp, mysql.ErrWarnAllowedPacketOverflowed: mysql.ErrWarnAllowedPacketOverflowed, + mysql.WarnOptionIgnored: mysql.WarnOptionIgnored, mysql.ErrTruncatedWrongValue: mysql.ErrTruncatedWrongValue, } terror.ErrClassToMySQLCodes[terror.ClassExpression] = expressionMySQLErrCodes diff --git a/expression/integration_test.go b/expression/integration_test.go index 376a69f71f398..520e8c2f2ece7 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -989,16 +989,31 @@ func (s *testIntegrationSuite) TestEncryptionBuiltin(c *C) { tk.MustExec("drop table if exists t") tk.MustExec("create table t(a char(10), b int, c double, d datetime, e time, f bit(4), g binary(20), h blob(10), i text(30))") tk.MustExec(`insert into t values('2', 2, 2.3, "2017-01-01 12:01:01", "12:01:01", 0b1010, "512", "48", "tidb")`) + tk.MustExec("SET block_encryption_mode='aes-128-ecb';") result = tk.MustQuery("select HEX(AES_ENCRYPT(a, 'key')), HEX(AES_ENCRYPT(b, 'key')), HEX(AES_ENCRYPT(c, 'key')), HEX(AES_ENCRYPT(d, 'key')), HEX(AES_ENCRYPT(e, 'key')), HEX(AES_ENCRYPT(f, 'key')), HEX(AES_ENCRYPT(g, 'key')), HEX(AES_ENCRYPT(h, 'key')), HEX(AES_ENCRYPT(i, 'key')) from t") result.Check(testkit.Rows("B3800B3A3CB4ECE2051A3E80FE373EAC B3800B3A3CB4ECE2051A3E80FE373EAC 9E018F7F2838DBA23C57F0E4CCF93287 E764D3E9D4AF8F926CD0979DDB1D0AF40C208B20A6C39D5D028644885280973A C452FFEEB76D3F5E9B26B8D48F7A228C 181BD5C81CBD36779A3C9DD5FF486B35 CE15F14AC7FF4E56ECCF148DE60E4BEDBDB6900AD51383970A5F32C59B3AC6E3 E1B29995CCF423C75519790F54A08CD2 84525677E95AC97698D22E1125B67E92")) result = tk.MustQuery("select HEX(AES_ENCRYPT('123', 'foobar')), HEX(AES_ENCRYPT(123, 'foobar')), HEX(AES_ENCRYPT('', 'foobar')), HEX(AES_ENCRYPT('你好', 'foobar')), AES_ENCRYPT(NULL, 'foobar')") result.Check(testkit.Rows(`45ABDD5C4802EFA6771A94C43F805208 45ABDD5C4802EFA6771A94C43F805208 791F1AEB6A6B796E6352BF381895CA0E D0147E2EB856186F146D9F6DE33F9546 `)) + result = tk.MustQuery("select HEX(AES_ENCRYPT(a, 'key', 'iv')), HEX(AES_ENCRYPT(b, 'key', 'iv')) from t") + result.Check(testkit.Rows("B3800B3A3CB4ECE2051A3E80FE373EAC B3800B3A3CB4ECE2051A3E80FE373EAC")) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1618| option ignored", "Warning|1618| option ignored")) + tk.MustExec("SET block_encryption_mode='aes-128-cbc';") + result = tk.MustQuery("select HEX(AES_ENCRYPT(a, 'key', '1234567890123456')), HEX(AES_ENCRYPT(b, 'key', '1234567890123456')), HEX(AES_ENCRYPT(c, 'key', '1234567890123456')), HEX(AES_ENCRYPT(d, 'key', '1234567890123456')), HEX(AES_ENCRYPT(e, 'key', '1234567890123456')), HEX(AES_ENCRYPT(f, 'key', '1234567890123456')), HEX(AES_ENCRYPT(g, 'key', '1234567890123456')), HEX(AES_ENCRYPT(h, 'key', '1234567890123456')), HEX(AES_ENCRYPT(i, 'key', '1234567890123456')) from t") + result.Check(testkit.Rows("341672829F84CB6B0BE690FEC4C4DAE9 341672829F84CB6B0BE690FEC4C4DAE9 D43734E147A12BB96C6897C4BBABA283 16F2C972411948DCEF3659B726D2CCB04AD1379A1A367FA64242058A50211B67 41E71D0C58967C1F50EEC074523946D1 1117D292E2D39C3EAA3B435371BE56FC 8ACB7ECC0883B672D7BD1CFAA9FA5FAF5B731ADE978244CD581F114D591C2E7E D2B13C30937E3251AEDA73859BA32E4B 2CF4A6051FF248A67598A17AA2C17267")) + result = tk.MustQuery("select HEX(AES_ENCRYPT('123', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT(123, 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('你好', 'foobar', '1234567890123456')), AES_ENCRYPT(NULL, 'foobar', '1234567890123456')") + result.Check(testkit.Rows(`80D5646F07B4654B05A02D9085759770 80D5646F07B4654B05A02D9085759770 B3C14BA15030D2D7E99376DBE011E752 0CD2936EE4FEC7A8CDF6208438B2BC05 `)) // for AES_DECRYPT + tk.MustExec("SET block_encryption_mode='aes-128-ecb';") result = tk.MustQuery("select AES_DECRYPT(AES_ENCRYPT('foo', 'bar'), 'bar')") result.Check(testkit.Rows("foo")) result = tk.MustQuery("select AES_DECRYPT(UNHEX('45ABDD5C4802EFA6771A94C43F805208'), 'foobar'), AES_DECRYPT(UNHEX('791F1AEB6A6B796E6352BF381895CA0E'), 'foobar'), AES_DECRYPT(UNHEX('D0147E2EB856186F146D9F6DE33F9546'), 'foobar'), AES_DECRYPT(NULL, 'foobar'), AES_DECRYPT('SOME_THING_STRANGE', 'foobar')") result.Check(testkit.Rows(`123 你好 `)) + tk.MustExec("SET block_encryption_mode='aes-128-cbc';") + result = tk.MustQuery("select AES_DECRYPT(AES_ENCRYPT('foo', 'bar', '1234567890123456'), 'bar', '1234567890123456')") + result.Check(testkit.Rows("foo")) + result = tk.MustQuery("select AES_DECRYPT(UNHEX('80D5646F07B4654B05A02D9085759770'), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX('B3C14BA15030D2D7E99376DBE011E752'), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX('0CD2936EE4FEC7A8CDF6208438B2BC05'), 'foobar', '1234567890123456'), AES_DECRYPT(NULL, 'foobar', '1234567890123456'), AES_DECRYPT('SOME_THING_STRANGE', 'foobar', '1234567890123456')") + result.Check(testkit.Rows(`123 你好 `)) // for COMPRESS tk.MustExec("DROP TABLE IF EXISTS t1;") diff --git a/session/session.go b/session/session.go index 6ced52fe474cc..df40b70463e50 100644 --- a/session/session.go +++ b/session/session.go @@ -1258,6 +1258,7 @@ const loadCommonGlobalVarsSQL = "select HIGH_PRIORITY * from mysql.global_variab variable.SQLModeVar + quoteCommaQuote + variable.MaxAllowedPacket + quoteCommaQuote + variable.TimeZone + quoteCommaQuote + + variable.BlockEncryptionMode + quoteCommaQuote + /* TiDB specific global variables: */ variable.TiDBSkipUTF8Check + quoteCommaQuote + variable.TiDBIndexJoinBatchSize + quoteCommaQuote + diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index ebc118b7722b8..e7e27eaeb919a 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -246,7 +246,7 @@ var defaultSysVars = []*SysVar{ {ScopeNone, "myisam_mmap_size", "18446744073709551615"}, {ScopeGlobal, "init_slave", ""}, {ScopeNone, "innodb_buffer_pool_instances", "8"}, - {ScopeGlobal | ScopeSession, "block_encryption_mode", "aes-128-ecb"}, + {ScopeGlobal | ScopeSession, BlockEncryptionMode, "aes-128-ecb"}, {ScopeGlobal | ScopeSession, "max_length_for_sort_data", "1024"}, {ScopeNone, "character_set_system", "utf8"}, {ScopeGlobal | ScopeSession, "interactive_timeout", "28800"}, @@ -766,6 +766,8 @@ const ( ConnectTimeout = "connect_timeout" // SyncBinlog is the name for 'sync_binlog' system variable. SyncBinlog = "sync_binlog" + // BlockEncryptionMode is the name for 'block_encryption_mode' system variable. + BlockEncryptionMode = "block_encryption_mode" ) // GlobalVarAccessor is the interface for accessing global scope system and status variables. diff --git a/util/encrypt/aes.go b/util/encrypt/aes.go index 3f0e77061a4bd..b1b90f3a524b8 100644 --- a/util/encrypt/aes.go +++ b/util/encrypt/aes.go @@ -124,18 +124,8 @@ func AESEncryptWithECB(str, key []byte) ([]byte, error) { if err != nil { return nil, errors.Trace(err) } - blockSize := cb.BlockSize() - // The str arguments can be any length, and padding is automatically added to - // str so it is a multiple of a block as required by block-based algorithms such as AES. - // This padding is automatically removed by the AES_DECRYPT() function. - data, err := PKCS7Pad(str, blockSize) - if err != nil { - return nil, err - } - crypted := make([]byte, len(data)) - ecb := newECBEncrypter(cb) - ecb.CryptBlocks(crypted, data) - return crypted, nil + mode := newECBEncrypter(cb) + return aesEncrypt(str, mode) } // AESDecryptWithECB decrypts data using AES with ECB mode. @@ -144,18 +134,8 @@ func AESDecryptWithECB(cryptStr, key []byte) ([]byte, error) { if err != nil { return nil, errors.Trace(err) } - blockSize := cb.BlockSize() - if len(cryptStr)%blockSize != 0 { - return nil, errors.New("Corrupted data") - } mode := newECBDecrypter(cb) - data := make([]byte, len(cryptStr)) - mode.CryptBlocks(data, cryptStr) - plain, err := PKCS7Unpad(data, blockSize) - if err != nil { - return nil, err - } - return plain, nil + return aesDecrypt(cryptStr, mode) } // DeriveKeyMySQL derives the encryption key from a password in MySQL algorithm. @@ -172,3 +152,53 @@ func DeriveKeyMySQL(key []byte, blockSize int) []byte { } return rKey } + +// AESEncryptWithCBC encrypts data using AES with CBC mode. +func AESEncryptWithCBC(str, key []byte, iv []byte) ([]byte, error) { + cb, err := aes.NewCipher(key) + if err != nil { + return nil, errors.Trace(err) + } + mode := cipher.NewCBCEncrypter(cb, iv) + return aesEncrypt(str, mode) +} + +// AESDecryptWithCBC decrypts data using AES with CBC mode. +func AESDecryptWithCBC(cryptStr, key []byte, iv []byte) ([]byte, error) { + cb, err := aes.NewCipher(key) + if err != nil { + return nil, errors.Trace(err) + } + mode := cipher.NewCBCDecrypter(cb, iv) + return aesDecrypt(cryptStr, mode) +} + +// aesDecrypt decrypts data using AES. +func aesDecrypt(cryptStr []byte, mode cipher.BlockMode) ([]byte, error) { + blockSize := mode.BlockSize() + if len(cryptStr)%blockSize != 0 { + return nil, errors.New("Corrupted data") + } + data := make([]byte, len(cryptStr)) + mode.CryptBlocks(data, cryptStr) + plain, err := PKCS7Unpad(data, blockSize) + if err != nil { + return nil, err + } + return plain, nil +} + +// aesEncrypt encrypts data using AES. +func aesEncrypt(str []byte, mode cipher.BlockMode) ([]byte, error) { + blockSize := mode.BlockSize() + // The str arguments can be any length, and padding is automatically added to + // str so it is a multiple of a block as required by block-based algorithms such as AES. + // This padding is automatically removed by the AES_DECRYPT() function. + data, err := PKCS7Pad(str, blockSize) + if err != nil { + return nil, err + } + crypted := make([]byte, len(data)) + mode.CryptBlocks(crypted, data) + return crypted, nil +} diff --git a/util/encrypt/aes_test.go b/util/encrypt/aes_test.go index d1e28128e2aa6..f93b83fde4307 100644 --- a/util/encrypt/aes_test.go +++ b/util/encrypt/aes_test.go @@ -269,6 +269,79 @@ func (s *testEncryptSuite) TestAESDecryptWithECB(c *C) { } } +func (s *testEncryptSuite) TestAESEncryptWithCBC(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + str string + key string + iv string + expect string + isError bool + }{ + // 128 bits key + {"pingcap", "1234567890123456", "1234567890123456", "2ECA0077C5EA5768A0485AA522774792", false}, + {"pingcap123", "1234567890123456", "1234567890123456", "042962D340F2F95BCC07B56EAC378D3A", false}, + // 192 bits key + {"pingcap", "123456789012345678901234", "1234567890123456", "EDECE05D9FE662E381130F7F19BA67F7", false}, // 192 bit + // negtive cases: invalid key length + {"pingcap", "12345678901234567", "1234567890123456", "", true}, + {"pingcap", "123456789012345", "1234567890123456", "", true}, + } + + for _, t := range tests { + str := []byte(t.str) + key := []byte(t.key) + iv := []byte(t.iv) + + crypted, err := AESEncryptWithCBC(str, key, iv) + if t.isError { + c.Assert(err, NotNil, Commentf("%v", t)) + continue + } + c.Assert(err, IsNil, Commentf("%v", t)) + result := toHex(crypted) + c.Assert(result, Equals, t.expect, Commentf("%v", t)) + } +} + +func (s *testEncryptSuite) TestAESDecryptWithCBC(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + expect string + key string + iv string + hexCryptStr string + isError bool + }{ + // 128 bits key + {"pingcap", "1234567890123456", "1234567890123456", "2ECA0077C5EA5768A0485AA522774792", false}, + {"pingcap123", "1234567890123456", "1234567890123456", "042962D340F2F95BCC07B56EAC378D3A", false}, + // 192 bits key + {"pingcap", "123456789012345678901234", "1234567890123456", "EDECE05D9FE662E381130F7F19BA67F7", false}, // 192 bit + // negtive cases: invalid key length + {"pingcap", "12345678901234567", "1234567890123456", "", true}, + {"pingcap", "123456789012345", "1234567890123456", "", true}, + // negtive cases: invalid padding / padding size + {"", "1234567890123456", "1234567890123456", "11223344556677112233", true}, + {"", "1234567890123456", "1234567890123456", "11223344556677112233112233445566", true}, + {"", "1234567890123456", "1234567890123456", "1122334455667711223311223344556611", true}, + } + + for _, t := range tests { + cryptStr, _ := hex.DecodeString(t.hexCryptStr) + key := []byte(t.key) + iv := []byte(t.iv) + + result, err := AESDecryptWithCBC(cryptStr, key, iv) + if t.isError { + c.Assert(err, NotNil) + continue + } + c.Assert(err, IsNil) + c.Assert(string(result), Equals, t.expect) + } +} + func (s *testEncryptSuite) TestDeriveKeyMySQL(c *C) { defer testleak.AfterTest(c)()