From db1fb1f254125991158e18176b792436ae2fd222 Mon Sep 17 00:00:00 2001 From: crazycs520 Date: Mon, 18 Feb 2019 16:50:00 +0800 Subject: [PATCH] expression: ECB/CBC modes with 128/192/256-bit key length for AES (#7425) --- expression/builtin_encryption.go | 203 ++++++++++++++++++++++++-- expression/builtin_encryption_test.go | 78 ++++++++-- expression/errors.go | 2 + expression/integration_test.go | 15 ++ session/session.go | 1 + sessionctx/variable/sysvar.go | 4 +- util/encrypt/aes.go | 76 +++++++--- util/encrypt/aes_test.go | 73 +++++++++ 8 files changed, 396 insertions(+), 56 deletions(-) diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index bfce4fbbe52a1..44cb8638332e2 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -16,6 +16,7 @@ package expression import ( "bytes" "compress/zlib" + "crypto/aes" "crypto/md5" "crypto/rand" "crypto/sha1" @@ -25,10 +26,12 @@ import ( "fmt" "hash" "io" + "strings" "github.com/juju/errors" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/auth" "github.com/pingcap/tidb/util/encrypt" @@ -56,7 +59,9 @@ var ( var ( _ builtinFunc = &builtinAesDecryptSig{} + _ builtinFunc = &builtinAesDecryptIVSig{} _ builtinFunc = &builtinAesEncryptSig{} + _ builtinFunc = &builtinAesEncryptIVSig{} _ builtinFunc = &builtinCompressSig{} _ builtinFunc = &builtinMD5Sig{} _ builtinFunc = &builtinPasswordSig{} @@ -68,9 +73,26 @@ var ( ) // TODO: support other mode -const ( - aes128ecbBlobkSize = 16 -) +const ivSize = aes.BlockSize + +// aesModeAttr indicates that the key length and iv attribute for specific block_encryption_mode. +// keySize is the key length in bits and mode is the encryption mode. +// ivRequired indicates that initialization vector is required or not. +type aesModeAttr struct { + modeName string + keySize int + ivRequired bool +} + +var aesModes = map[string]*aesModeAttr{ + //TODO support more modes, permitted mode values are: ECB, CBC, CFB1, CFB8, CFB128, OFB + "aes-128-ecb": {"ecb", 16, false}, + "aes-192-ecb": {"ecb", 24, false}, + "aes-256-ecb": {"ecb", 32, false}, + "aes-128-cbc": {"cbc", 16, true}, + "aes-192-cbc": {"cbc", 24, true}, + "aes-256-cbc": {"cbc", 32, true}, +} type aesDecryptFunctionClass struct { baseFunctionClass @@ -80,20 +102,37 @@ func (c *aesDecryptFunctionClass) getFunction(ctx sessionctx.Context, args []Exp if err := c.verifyArgs(args); err != nil { return nil, errors.Trace(c.verifyArgs(args)) } - bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString, types.ETString) + argTps := make([]types.EvalType, 0, len(args)) + for range args { + argTps = append(argTps, types.ETString) + } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, argTps...) bf.tp.Flen = args[0].GetType().Flen // At most. types.SetBinChsClnFlag(bf.tp) - sig := &builtinAesDecryptSig{bf} - return sig, nil + + blockMode, _ := ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) + mode, exists := aesModes[strings.ToLower(blockMode)] + if !exists { + return nil, errors.Errorf("unsupported block encryption mode - %v", blockMode) + } + if mode.ivRequired { + if len(args) != 3 { + return nil, ErrIncorrectParameterCount.GenByArgs("aes_decrypt") + } + return &builtinAesDecryptIVSig{bf, mode}, nil + } + return &builtinAesDecryptSig{bf, mode}, nil } type builtinAesDecryptSig struct { baseBuiltinFunc + *aesModeAttr } func (b *builtinAesDecryptSig) Clone() builtinFunc { newSig := &builtinAesDecryptSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.aesModeAttr = b.aesModeAttr return newSig } @@ -105,15 +144,73 @@ func (b *builtinAesDecryptSig) evalString(row types.Row) (string, bool, error) { if isNull || err != nil { return "", true, errors.Trace(err) } + keyStr, isNull, err := b.args[1].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + if !b.ivRequired && len(b.args) == 3 { + // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV")) + } + + key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize) + var plainText []byte + switch b.modeName { + case "ecb": + plainText, err = encrypt.AESDecryptWithECB([]byte(cryptStr), key) + default: + return "", true, errors.Errorf("unsupported block encryption mode - %v", b.modeName) + } + if err != nil { + return "", true, nil + } + return string(plainText), false, nil +} + +type builtinAesDecryptIVSig struct { + baseBuiltinFunc + *aesModeAttr +} + +func (b *builtinAesDecryptIVSig) Clone() builtinFunc { + newSig := &builtinAesDecryptIVSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.aesModeAttr = b.aesModeAttr + return newSig +} + +// evalString evals AES_DECRYPT(crypt_str, key_key, iv). +// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_aes-decrypt +func (b *builtinAesDecryptIVSig) evalString(row types.Row) (string, bool, error) { + // According to doc: If either function argument is NULL, the function returns NULL. + cryptStr, isNull, err := b.args[0].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } keyStr, isNull, err := b.args[1].EvalString(b.ctx, row) if isNull || err != nil { return "", true, errors.Trace(err) } - // TODO: Support other modes. - key := encrypt.DeriveKeyMySQL([]byte(keyStr), aes128ecbBlobkSize) - plainText, err := encrypt.AESDecryptWithECB([]byte(cryptStr), key) + iv, isNull, err := b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + if len(iv) < aes.BlockSize { + return "", true, errIncorrectArgs.Gen("The initialization vector supplied to aes_decrypt is too short. Must be at least %d bytes long", aes.BlockSize) + } + // init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored) + iv = iv[0:aes.BlockSize] + + key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize) + var plainText []byte + switch b.modeName { + case "cbc": + plainText, err = encrypt.AESDecryptWithCBC([]byte(cryptStr), key, []byte(iv)) + default: + return "", true, errors.Errorf("unsupported block encryption mode - %v", b.modeName) + } if err != nil { return "", true, nil } @@ -128,20 +225,37 @@ func (c *aesEncryptFunctionClass) getFunction(ctx sessionctx.Context, args []Exp if err := c.verifyArgs(args); err != nil { return nil, errors.Trace(c.verifyArgs(args)) } - bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString, types.ETString) - bf.tp.Flen = aes128ecbBlobkSize * (args[0].GetType().Flen/aes128ecbBlobkSize + 1) // At most. + argTps := make([]types.EvalType, 0, len(args)) + for range args { + argTps = append(argTps, types.ETString) + } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, argTps...) + bf.tp.Flen = aes.BlockSize * (args[0].GetType().Flen/aes.BlockSize + 1) // At most. types.SetBinChsClnFlag(bf.tp) - sig := &builtinAesEncryptSig{bf} - return sig, nil + + blockMode, _ := ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode) + mode, exists := aesModes[strings.ToLower(blockMode)] + if !exists { + return nil, errors.Errorf("unsupported block encryption mode - %v", blockMode) + } + if mode.ivRequired { + if len(args) != 3 { + return nil, ErrIncorrectParameterCount.GenByArgs("aes_encrypt") + } + return &builtinAesEncryptIVSig{bf, mode}, nil + } + return &builtinAesEncryptSig{bf, mode}, nil } type builtinAesEncryptSig struct { baseBuiltinFunc + *aesModeAttr } func (b *builtinAesEncryptSig) Clone() builtinFunc { newSig := &builtinAesEncryptSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.aesModeAttr = b.aesModeAttr return newSig } @@ -153,15 +267,72 @@ func (b *builtinAesEncryptSig) evalString(row types.Row) (string, bool, error) { if isNull || err != nil { return "", true, errors.Trace(err) } + keyStr, isNull, err := b.args[1].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + if !b.ivRequired && len(b.args) == 3 { + // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV")) + } + + key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize) + var cipherText []byte + switch b.modeName { + case "ecb": + cipherText, err = encrypt.AESEncryptWithECB([]byte(str), key) + default: + return "", true, errors.Errorf("unsupported block encryption mode - %v", b.modeName) + } + if err != nil { + return "", true, nil + } + return string(cipherText), false, nil +} +type builtinAesEncryptIVSig struct { + baseBuiltinFunc + *aesModeAttr +} + +func (b *builtinAesEncryptIVSig) Clone() builtinFunc { + newSig := &builtinAesEncryptIVSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.aesModeAttr = b.aesModeAttr + return newSig +} + +// evalString evals AES_ENCRYPT(str, key_str, iv). +// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_aes-decrypt +func (b *builtinAesEncryptIVSig) evalString(row types.Row) (string, bool, error) { + // According to doc: If either function argument is NULL, the function returns NULL. + str, isNull, err := b.args[0].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } keyStr, isNull, err := b.args[1].EvalString(b.ctx, row) if isNull || err != nil { return "", true, errors.Trace(err) } - // TODO: Support other modes. - key := encrypt.DeriveKeyMySQL([]byte(keyStr), aes128ecbBlobkSize) - cipherText, err := encrypt.AESEncryptWithECB([]byte(str), key) + iv, isNull, err := b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + if len(iv) < aes.BlockSize { + return "", true, errIncorrectArgs.Gen("The initialization vector supplied to aes_encrypt is too short. Must be at least %d bytes long", aes.BlockSize) + } + // init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored) + iv = iv[0:aes.BlockSize] + + key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize) + var cipherText []byte + switch b.modeName { + case "cbc": + cipherText, err = encrypt.AESEncryptWithCBC([]byte(str), key, []byte(iv)) + default: + return "", true, errors.Errorf("unsupported block encryption mode - %v", b.modeName) + } if err != nil { return "", true, nil } diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index 3a4ed2c96168f..40c7398291b3a 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -15,6 +15,8 @@ package expression import ( "encoding/hex" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/util/chunk" "strings" . "github.com/pingcap/check" @@ -26,47 +28,69 @@ import ( ) var aesTests = []struct { + mode string origin interface{} - key interface{} + params []interface{} crypt interface{} }{ - {"pingcap", "1234567890123456", "697BFE9B3F8C2F289DD82C88C7BC95C4"}, - {"pingcap123", "1234567890123456", "CEC348F4EF5F84D3AA6C4FA184C65766"}, - {"pingcap", "123456789012345678901234", "6F1589686860C8E8C7A40A78B25FF2C0"}, - {"pingcap", "123", "996E0CA8688D7AD20819B90B273E01C6"}, - {"pingcap", 123, "996E0CA8688D7AD20819B90B273E01C6"}, - {nil, 123, nil}, + // test for ecb + {"aes-128-ecb", "pingcap", []interface{}{"1234567890123456"}, "697BFE9B3F8C2F289DD82C88C7BC95C4"}, + {"aes-128-ecb", "pingcap123", []interface{}{"1234567890123456"}, "CEC348F4EF5F84D3AA6C4FA184C65766"}, + {"aes-128-ecb", "pingcap", []interface{}{"123456789012345678901234"}, "6F1589686860C8E8C7A40A78B25FF2C0"}, + {"aes-128-ecb", "pingcap", []interface{}{"123"}, "996E0CA8688D7AD20819B90B273E01C6"}, + {"aes-128-ecb", "pingcap", []interface{}{123}, "996E0CA8688D7AD20819B90B273E01C6"}, + {"aes-128-ecb", nil, []interface{}{123}, nil}, + {"aes-192-ecb", "pingcap", []interface{}{"1234567890123456"}, "9B139FD002E6496EA2D5C73A2265E661"}, + {"aes-256-ecb", "pingcap", []interface{}{"1234567890123456"}, "F80DCDEDDBE5663BDB68F74AEDDB8EE3"}, + // test for cbc + {"aes-128-cbc", "pingcap", []interface{}{"1234567890123456", "1234567890123456"}, "2ECA0077C5EA5768A0485AA522774792"}, + {"aes-128-cbc", "pingcap", []interface{}{"123456789012345678901234", "1234567890123456"}, "483788634DA8817423BA0934FD2C096E"}, + {"aes-192-cbc", "pingcap", []interface{}{"1234567890123456", "1234567890123456"}, "516391DB38E908ECA93AAB22870EC787"}, + {"aes-256-cbc", "pingcap", []interface{}{"1234567890123456", "1234567890123456"}, "5D0E22C1E77523AEF5C3E10B65653C8F"}, + {"aes-256-cbc", "pingcap", []interface{}{"12345678901234561234567890123456", "1234567890123456"}, "A26BA27CA4BE9D361D545AA84A17002D"}, + {"aes-256-cbc", "pingcap", []interface{}{"1234567890123456", "12345678901234561234567890123456"}, "5D0E22C1E77523AEF5C3E10B65653C8F"}, } func (s *testEvaluatorSuite) TestAESEncrypt(c *C) { defer testleak.AfterTest(c)() fc := funcs[ast.AesEncrypt] for _, tt := range aesTests { - str := types.NewDatum(tt.origin) - key := types.NewDatum(tt.key) - f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{str, key})) + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum(tt.mode)) + args := []types.Datum{types.NewDatum(tt.origin)} + for _, param := range tt.params { + args = append(args, types.NewDatum(param)) + } + f, err := fc.getFunction(s.ctx, s.datumsToConstants(args)) crypt, err := evalBuiltinFunc(f, nil) c.Assert(err, IsNil) c.Assert(toHex(crypt), DeepEquals, types.NewDatum(tt.crypt)) } - s.testNullInput(c, ast.AesDecrypt) + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum("aes-128-ecb")) + s.testNullInput(c, ast.AesEncrypt) + s.testAmbiguousInput(c, ast.AesEncrypt) } func (s *testEvaluatorSuite) TestAESDecrypt(c *C) { defer testleak.AfterTest(c)() fc := funcs[ast.AesDecrypt] - for _, test := range aesTests { - cryptStr := fromHex(test.crypt) - key := types.NewDatum(test.key) - f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{cryptStr, key})) + for _, tt := range aesTests { + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum(tt.mode)) + args := []types.Datum{fromHex(tt.crypt)} + for _, param := range tt.params { + args = append(args, types.NewDatum(param)) + } + f, err := fc.getFunction(s.ctx, s.datumsToConstants(args)) str, err := evalBuiltinFunc(f, nil) c.Assert(err, IsNil) - c.Assert(str, DeepEquals, types.NewDatum(test.origin)) + c.Assert(str, DeepEquals, types.NewDatum(tt.origin)) } + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum("aes-128-ecb")) s.testNullInput(c, ast.AesDecrypt) + s.testAmbiguousInput(c, ast.AesDecrypt) } func (s *testEvaluatorSuite) testNullInput(c *C, fnName string) { + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum("aes-128-ecb")) fc := funcs[fnName] arg := types.NewStringDatum("str") var argNull types.Datum @@ -81,6 +105,28 @@ func (s *testEvaluatorSuite) testNullInput(c *C, fnName string) { c.Assert(crypt.IsNull(), IsTrue) } +func (s *testEvaluatorSuite) testAmbiguousInput(c *C, fnName string) { + fc := funcs[fnName] + arg := types.NewStringDatum("str") + // test for modes that require init_vector + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum("aes-128-cbc")) + _, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{arg, arg})) + c.Assert(err, NotNil) + f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{arg, arg, types.NewStringDatum("iv < 16 bytes")})) + c.Assert(err, IsNil) + _, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, NotNil) + + // test for modes that do not require init_vector + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.BlockEncryptionMode, types.NewDatum("aes-128-ecb")) + f, err = fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{arg, arg, arg})) + c.Assert(err, IsNil) + _, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(len(warnings), GreaterEqual, 1) +} + func toHex(d types.Datum) (h types.Datum) { if d.IsNull() { return diff --git a/expression/errors.go b/expression/errors.go index 951ae09ce5518..2a63a14ab67f6 100644 --- a/expression/errors.go +++ b/expression/errors.go @@ -34,6 +34,7 @@ var ( errDefaultValue = terror.ClassExpression.New(mysql.ErrInvalidDefault, "invalid default value") errDeprecatedSyntaxNoReplacement = terror.ClassExpression.New(mysql.ErrWarnDeprecatedSyntaxNoReplacement, mysql.MySQLErrName[mysql.ErrWarnDeprecatedSyntaxNoReplacement]) errWarnAllowedPacketOverflowed = terror.ClassExpression.New(mysql.ErrWarnAllowedPacketOverflowed, mysql.MySQLErrName[mysql.ErrWarnAllowedPacketOverflowed]) + errWarnOptionIgnored = terror.ClassExpression.New(mysql.WarnOptionIgnored, mysql.MySQLErrName[mysql.WarnOptionIgnored]) ) func init() { @@ -47,6 +48,7 @@ func init() { mysql.ErrInvalidDefault: mysql.ErrInvalidDefault, mysql.ErrWarnDeprecatedSyntaxNoReplacement: mysql.ErrWarnDeprecatedSyntaxNoReplacement, mysql.ErrWarnAllowedPacketOverflowed: mysql.ErrWarnAllowedPacketOverflowed, + mysql.WarnOptionIgnored: mysql.WarnOptionIgnored, } terror.ErrClassToMySQLCodes[terror.ClassExpression] = expressionMySQLErrCodes } diff --git a/expression/integration_test.go b/expression/integration_test.go index b84f37c572701..e5b9fb18d4734 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -971,16 +971,31 @@ func (s *testIntegrationSuite) TestEncryptionBuiltin(c *C) { tk.MustExec("drop table if exists t") tk.MustExec("create table t(a char(10), b int, c double, d datetime, e time, f bit(4), g binary(20), h blob(10), i text(30))") tk.MustExec(`insert into t values('2', 2, 2.3, "2017-01-01 12:01:01", "12:01:01", 0b1010, "512", "48", "tidb")`) + tk.MustExec("SET block_encryption_mode='aes-128-ecb';") result = tk.MustQuery("select HEX(AES_ENCRYPT(a, 'key')), HEX(AES_ENCRYPT(b, 'key')), HEX(AES_ENCRYPT(c, 'key')), HEX(AES_ENCRYPT(d, 'key')), HEX(AES_ENCRYPT(e, 'key')), HEX(AES_ENCRYPT(f, 'key')), HEX(AES_ENCRYPT(g, 'key')), HEX(AES_ENCRYPT(h, 'key')), HEX(AES_ENCRYPT(i, 'key')) from t") result.Check(testkit.Rows("B3800B3A3CB4ECE2051A3E80FE373EAC B3800B3A3CB4ECE2051A3E80FE373EAC 9E018F7F2838DBA23C57F0E4CCF93287 E764D3E9D4AF8F926CD0979DDB1D0AF40C208B20A6C39D5D028644885280973A C452FFEEB76D3F5E9B26B8D48F7A228C 181BD5C81CBD36779A3C9DD5FF486B35 CE15F14AC7FF4E56ECCF148DE60E4BEDBDB6900AD51383970A5F32C59B3AC6E3 E1B29995CCF423C75519790F54A08CD2 84525677E95AC97698D22E1125B67E92")) result = tk.MustQuery("select HEX(AES_ENCRYPT('123', 'foobar')), HEX(AES_ENCRYPT(123, 'foobar')), HEX(AES_ENCRYPT('', 'foobar')), HEX(AES_ENCRYPT('你好', 'foobar')), AES_ENCRYPT(NULL, 'foobar')") result.Check(testkit.Rows(`45ABDD5C4802EFA6771A94C43F805208 45ABDD5C4802EFA6771A94C43F805208 791F1AEB6A6B796E6352BF381895CA0E D0147E2EB856186F146D9F6DE33F9546 `)) + result = tk.MustQuery("select HEX(AES_ENCRYPT(a, 'key', 'iv')), HEX(AES_ENCRYPT(b, 'key', 'iv')) from t") + result.Check(testkit.Rows("B3800B3A3CB4ECE2051A3E80FE373EAC B3800B3A3CB4ECE2051A3E80FE373EAC")) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1618| option ignored", "Warning|1618| option ignored")) + tk.MustExec("SET block_encryption_mode='aes-128-cbc';") + result = tk.MustQuery("select HEX(AES_ENCRYPT(a, 'key', '1234567890123456')), HEX(AES_ENCRYPT(b, 'key', '1234567890123456')), HEX(AES_ENCRYPT(c, 'key', '1234567890123456')), HEX(AES_ENCRYPT(d, 'key', '1234567890123456')), HEX(AES_ENCRYPT(e, 'key', '1234567890123456')), HEX(AES_ENCRYPT(f, 'key', '1234567890123456')), HEX(AES_ENCRYPT(g, 'key', '1234567890123456')), HEX(AES_ENCRYPT(h, 'key', '1234567890123456')), HEX(AES_ENCRYPT(i, 'key', '1234567890123456')) from t") + result.Check(testkit.Rows("341672829F84CB6B0BE690FEC4C4DAE9 341672829F84CB6B0BE690FEC4C4DAE9 D43734E147A12BB96C6897C4BBABA283 16F2C972411948DCEF3659B726D2CCB04AD1379A1A367FA64242058A50211B67 41E71D0C58967C1F50EEC074523946D1 1117D292E2D39C3EAA3B435371BE56FC 8ACB7ECC0883B672D7BD1CFAA9FA5FAF5B731ADE978244CD581F114D591C2E7E D2B13C30937E3251AEDA73859BA32E4B 2CF4A6051FF248A67598A17AA2C17267")) + result = tk.MustQuery("select HEX(AES_ENCRYPT('123', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT(123, 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('你好', 'foobar', '1234567890123456')), AES_ENCRYPT(NULL, 'foobar', '1234567890123456')") + result.Check(testkit.Rows(`80D5646F07B4654B05A02D9085759770 80D5646F07B4654B05A02D9085759770 B3C14BA15030D2D7E99376DBE011E752 0CD2936EE4FEC7A8CDF6208438B2BC05 `)) // for AES_DECRYPT + tk.MustExec("SET block_encryption_mode='aes-128-ecb';") result = tk.MustQuery("select AES_DECRYPT(AES_ENCRYPT('foo', 'bar'), 'bar')") result.Check(testkit.Rows("foo")) result = tk.MustQuery("select AES_DECRYPT(UNHEX('45ABDD5C4802EFA6771A94C43F805208'), 'foobar'), AES_DECRYPT(UNHEX('791F1AEB6A6B796E6352BF381895CA0E'), 'foobar'), AES_DECRYPT(UNHEX('D0147E2EB856186F146D9F6DE33F9546'), 'foobar'), AES_DECRYPT(NULL, 'foobar'), AES_DECRYPT('SOME_THING_STRANGE', 'foobar')") result.Check(testkit.Rows(`123 你好 `)) + tk.MustExec("SET block_encryption_mode='aes-128-cbc';") + result = tk.MustQuery("select AES_DECRYPT(AES_ENCRYPT('foo', 'bar', '1234567890123456'), 'bar', '1234567890123456')") + result.Check(testkit.Rows("foo")) + result = tk.MustQuery("select AES_DECRYPT(UNHEX('80D5646F07B4654B05A02D9085759770'), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX('B3C14BA15030D2D7E99376DBE011E752'), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX('0CD2936EE4FEC7A8CDF6208438B2BC05'), 'foobar', '1234567890123456'), AES_DECRYPT(NULL, 'foobar', '1234567890123456'), AES_DECRYPT('SOME_THING_STRANGE', 'foobar', '1234567890123456')") + result.Check(testkit.Rows(`123 你好 `)) // for COMPRESS tk.MustExec("DROP TABLE IF EXISTS t1;") diff --git a/session/session.go b/session/session.go index 785369ca362ac..85ddfb6bcb10b 100644 --- a/session/session.go +++ b/session/session.go @@ -1277,6 +1277,7 @@ const loadCommonGlobalVarsSQL = "select HIGH_PRIORITY * from mysql.global_variab variable.SQLModeVar + quoteCommaQuote + variable.MaxAllowedPacket + quoteCommaQuote + variable.TimeZone + quoteCommaQuote + + variable.BlockEncryptionMode + quoteCommaQuote + /* TiDB specific global variables: */ variable.TiDBSkipUTF8Check + quoteCommaQuote + variable.TiDBIndexJoinBatchSize + quoteCommaQuote + diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 9dd1e3f9027c6..70ce423c86a6e 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -235,7 +235,7 @@ var defaultSysVars = []*SysVar{ {ScopeNone, "myisam_mmap_size", "18446744073709551615"}, {ScopeGlobal, "init_slave", ""}, {ScopeNone, "innodb_buffer_pool_instances", "8"}, - {ScopeGlobal | ScopeSession, "block_encryption_mode", "aes-128-ecb"}, + {ScopeGlobal | ScopeSession, BlockEncryptionMode, "aes-128-ecb"}, {ScopeGlobal | ScopeSession, "max_length_for_sort_data", "1024"}, {ScopeNone, "character_set_system", "utf8"}, {ScopeGlobal | ScopeSession, "interactive_timeout", "28800"}, @@ -674,6 +674,8 @@ const ( CharsetDatabase = "character_set_database" // CollationDatabase is the name for collation_database system variable. CollationDatabase = "collation_database" + // BlockEncryptionMode is the name for 'block_encryption_mode' system variable. + BlockEncryptionMode = "block_encryption_mode" ) // GlobalVarAccessor is the interface for accessing global scope system and status variables. diff --git a/util/encrypt/aes.go b/util/encrypt/aes.go index deeea49e48976..befb86077b3ad 100644 --- a/util/encrypt/aes.go +++ b/util/encrypt/aes.go @@ -124,18 +124,8 @@ func AESEncryptWithECB(str, key []byte) ([]byte, error) { if err != nil { return nil, errors.Trace(err) } - blockSize := cb.BlockSize() - // The str arguments can be any length, and padding is automatically added to - // str so it is a multiple of a block as required by block-based algorithms such as AES. - // This padding is automatically removed by the AES_DECRYPT() function. - data, err := PKCS7Pad(str, blockSize) - if err != nil { - return nil, err - } - crypted := make([]byte, len(data)) - ecb := newECBEncrypter(cb) - ecb.CryptBlocks(crypted, data) - return crypted, nil + mode := newECBEncrypter(cb) + return aesEncrypt(str, mode) } // AESDecryptWithECB decrypts data using AES with ECB mode. @@ -144,18 +134,8 @@ func AESDecryptWithECB(cryptStr, key []byte) ([]byte, error) { if err != nil { return nil, errors.Trace(err) } - blockSize := cb.BlockSize() - if len(cryptStr)%blockSize != 0 { - return nil, errors.New("Corrupted data") - } mode := newECBDecrypter(cb) - data := make([]byte, len(cryptStr)) - mode.CryptBlocks(data, cryptStr) - plain, err := PKCS7Unpad(data, blockSize) - if err != nil { - return nil, err - } - return plain, nil + return aesDecrypt(cryptStr, mode) } // DeriveKeyMySQL derives the encryption key from a password in MySQL algorithm. @@ -172,3 +152,53 @@ func DeriveKeyMySQL(key []byte, blockSize int) []byte { } return rKey } + +// AESEncryptWithCBC encrypts data using AES with CBC mode. +func AESEncryptWithCBC(str, key []byte, iv []byte) ([]byte, error) { + cb, err := aes.NewCipher(key) + if err != nil { + return nil, errors.Trace(err) + } + mode := cipher.NewCBCEncrypter(cb, iv) + return aesEncrypt(str, mode) +} + +// AESDecryptWithCBC decrypts data using AES with CBC mode. +func AESDecryptWithCBC(cryptStr, key []byte, iv []byte) ([]byte, error) { + cb, err := aes.NewCipher(key) + if err != nil { + return nil, errors.Trace(err) + } + mode := cipher.NewCBCDecrypter(cb, iv) + return aesDecrypt(cryptStr, mode) +} + +// aesDecrypt decrypts data using AES. +func aesDecrypt(cryptStr []byte, mode cipher.BlockMode) ([]byte, error) { + blockSize := mode.BlockSize() + if len(cryptStr)%blockSize != 0 { + return nil, errors.New("Corrupted data") + } + data := make([]byte, len(cryptStr)) + mode.CryptBlocks(data, cryptStr) + plain, err := PKCS7Unpad(data, blockSize) + if err != nil { + return nil, err + } + return plain, nil +} + +// aesEncrypt encrypts data using AES. +func aesEncrypt(str []byte, mode cipher.BlockMode) ([]byte, error) { + blockSize := mode.BlockSize() + // The str arguments can be any length, and padding is automatically added to + // str so it is a multiple of a block as required by block-based algorithms such as AES. + // This padding is automatically removed by the AES_DECRYPT() function. + data, err := PKCS7Pad(str, blockSize) + if err != nil { + return nil, err + } + crypted := make([]byte, len(data)) + mode.CryptBlocks(crypted, data) + return crypted, nil +} diff --git a/util/encrypt/aes_test.go b/util/encrypt/aes_test.go index d1e28128e2aa6..f93b83fde4307 100644 --- a/util/encrypt/aes_test.go +++ b/util/encrypt/aes_test.go @@ -269,6 +269,79 @@ func (s *testEncryptSuite) TestAESDecryptWithECB(c *C) { } } +func (s *testEncryptSuite) TestAESEncryptWithCBC(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + str string + key string + iv string + expect string + isError bool + }{ + // 128 bits key + {"pingcap", "1234567890123456", "1234567890123456", "2ECA0077C5EA5768A0485AA522774792", false}, + {"pingcap123", "1234567890123456", "1234567890123456", "042962D340F2F95BCC07B56EAC378D3A", false}, + // 192 bits key + {"pingcap", "123456789012345678901234", "1234567890123456", "EDECE05D9FE662E381130F7F19BA67F7", false}, // 192 bit + // negtive cases: invalid key length + {"pingcap", "12345678901234567", "1234567890123456", "", true}, + {"pingcap", "123456789012345", "1234567890123456", "", true}, + } + + for _, t := range tests { + str := []byte(t.str) + key := []byte(t.key) + iv := []byte(t.iv) + + crypted, err := AESEncryptWithCBC(str, key, iv) + if t.isError { + c.Assert(err, NotNil, Commentf("%v", t)) + continue + } + c.Assert(err, IsNil, Commentf("%v", t)) + result := toHex(crypted) + c.Assert(result, Equals, t.expect, Commentf("%v", t)) + } +} + +func (s *testEncryptSuite) TestAESDecryptWithCBC(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + expect string + key string + iv string + hexCryptStr string + isError bool + }{ + // 128 bits key + {"pingcap", "1234567890123456", "1234567890123456", "2ECA0077C5EA5768A0485AA522774792", false}, + {"pingcap123", "1234567890123456", "1234567890123456", "042962D340F2F95BCC07B56EAC378D3A", false}, + // 192 bits key + {"pingcap", "123456789012345678901234", "1234567890123456", "EDECE05D9FE662E381130F7F19BA67F7", false}, // 192 bit + // negtive cases: invalid key length + {"pingcap", "12345678901234567", "1234567890123456", "", true}, + {"pingcap", "123456789012345", "1234567890123456", "", true}, + // negtive cases: invalid padding / padding size + {"", "1234567890123456", "1234567890123456", "11223344556677112233", true}, + {"", "1234567890123456", "1234567890123456", "11223344556677112233112233445566", true}, + {"", "1234567890123456", "1234567890123456", "1122334455667711223311223344556611", true}, + } + + for _, t := range tests { + cryptStr, _ := hex.DecodeString(t.hexCryptStr) + key := []byte(t.key) + iv := []byte(t.iv) + + result, err := AESDecryptWithCBC(cryptStr, key, iv) + if t.isError { + c.Assert(err, NotNil) + continue + } + c.Assert(err, IsNil) + c.Assert(string(result), Equals, t.expect) + } +} + func (s *testEncryptSuite) TestDeriveKeyMySQL(c *C) { defer testleak.AfterTest(c)()