Skip to content

Commit

Permalink
expression: ECB/CBC modes with 128/192/256-bit key length for AES (pi…
Browse files Browse the repository at this point in the history
  • Loading branch information
crazycs520 committed Feb 18, 2019
1 parent 29ec059 commit db1fb1f
Show file tree
Hide file tree
Showing 8 changed files with 396 additions and 56 deletions.
203 changes: 187 additions & 16 deletions expression/builtin_encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package expression
import (
"bytes"
"compress/zlib"
"crypto/aes"
"crypto/md5"
"crypto/rand"
"crypto/sha1"
Expand All @@ -25,10 +26,12 @@ import (
"fmt"
"hash"
"io"
"strings"

"github.com/juju/errors"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/auth"
"github.com/pingcap/tidb/util/encrypt"
Expand Down Expand Up @@ -56,7 +59,9 @@ var (

var (
_ builtinFunc = &builtinAesDecryptSig{}
_ builtinFunc = &builtinAesDecryptIVSig{}
_ builtinFunc = &builtinAesEncryptSig{}
_ builtinFunc = &builtinAesEncryptIVSig{}
_ builtinFunc = &builtinCompressSig{}
_ builtinFunc = &builtinMD5Sig{}
_ builtinFunc = &builtinPasswordSig{}
Expand All @@ -68,9 +73,26 @@ var (
)

// TODO: support other mode
const (
aes128ecbBlobkSize = 16
)
const ivSize = aes.BlockSize

// aesModeAttr indicates that the key length and iv attribute for specific block_encryption_mode.
// keySize is the key length in bits and mode is the encryption mode.
// ivRequired indicates that initialization vector is required or not.
type aesModeAttr struct {
modeName string
keySize int
ivRequired bool
}

var aesModes = map[string]*aesModeAttr{
//TODO support more modes, permitted mode values are: ECB, CBC, CFB1, CFB8, CFB128, OFB
"aes-128-ecb": {"ecb", 16, false},
"aes-192-ecb": {"ecb", 24, false},
"aes-256-ecb": {"ecb", 32, false},
"aes-128-cbc": {"cbc", 16, true},
"aes-192-cbc": {"cbc", 24, true},
"aes-256-cbc": {"cbc", 32, true},
}

type aesDecryptFunctionClass struct {
baseFunctionClass
Expand All @@ -80,20 +102,37 @@ func (c *aesDecryptFunctionClass) getFunction(ctx sessionctx.Context, args []Exp
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(c.verifyArgs(args))
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString, types.ETString)
argTps := make([]types.EvalType, 0, len(args))
for range args {
argTps = append(argTps, types.ETString)
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, argTps...)
bf.tp.Flen = args[0].GetType().Flen // At most.
types.SetBinChsClnFlag(bf.tp)
sig := &builtinAesDecryptSig{bf}
return sig, nil

blockMode, _ := ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode)
mode, exists := aesModes[strings.ToLower(blockMode)]
if !exists {
return nil, errors.Errorf("unsupported block encryption mode - %v", blockMode)
}
if mode.ivRequired {
if len(args) != 3 {
return nil, ErrIncorrectParameterCount.GenByArgs("aes_decrypt")
}
return &builtinAesDecryptIVSig{bf, mode}, nil
}
return &builtinAesDecryptSig{bf, mode}, nil
}

type builtinAesDecryptSig struct {
baseBuiltinFunc
*aesModeAttr
}

func (b *builtinAesDecryptSig) Clone() builtinFunc {
newSig := &builtinAesDecryptSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.aesModeAttr = b.aesModeAttr
return newSig
}

Expand All @@ -105,15 +144,73 @@ func (b *builtinAesDecryptSig) evalString(row types.Row) (string, bool, error) {
if isNull || err != nil {
return "", true, errors.Trace(err)
}
keyStr, isNull, err := b.args[1].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
}
if !b.ivRequired && len(b.args) == 3 {
// For modes that do not require init_vector, it is ignored and a warning is generated if it is specified.
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV"))
}

key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize)
var plainText []byte
switch b.modeName {
case "ecb":
plainText, err = encrypt.AESDecryptWithECB([]byte(cryptStr), key)
default:
return "", true, errors.Errorf("unsupported block encryption mode - %v", b.modeName)
}
if err != nil {
return "", true, nil
}
return string(plainText), false, nil
}

type builtinAesDecryptIVSig struct {
baseBuiltinFunc
*aesModeAttr
}

func (b *builtinAesDecryptIVSig) Clone() builtinFunc {
newSig := &builtinAesDecryptIVSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.aesModeAttr = b.aesModeAttr
return newSig
}

// evalString evals AES_DECRYPT(crypt_str, key_key, iv).
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_aes-decrypt
func (b *builtinAesDecryptIVSig) evalString(row types.Row) (string, bool, error) {
// According to doc: If either function argument is NULL, the function returns NULL.
cryptStr, isNull, err := b.args[0].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
}

keyStr, isNull, err := b.args[1].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
}

// TODO: Support other modes.
key := encrypt.DeriveKeyMySQL([]byte(keyStr), aes128ecbBlobkSize)
plainText, err := encrypt.AESDecryptWithECB([]byte(cryptStr), key)
iv, isNull, err := b.args[2].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
}
if len(iv) < aes.BlockSize {
return "", true, errIncorrectArgs.Gen("The initialization vector supplied to aes_decrypt is too short. Must be at least %d bytes long", aes.BlockSize)
}
// init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored)
iv = iv[0:aes.BlockSize]

key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize)
var plainText []byte
switch b.modeName {
case "cbc":
plainText, err = encrypt.AESDecryptWithCBC([]byte(cryptStr), key, []byte(iv))
default:
return "", true, errors.Errorf("unsupported block encryption mode - %v", b.modeName)
}
if err != nil {
return "", true, nil
}
Expand All @@ -128,20 +225,37 @@ func (c *aesEncryptFunctionClass) getFunction(ctx sessionctx.Context, args []Exp
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(c.verifyArgs(args))
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString, types.ETString)
bf.tp.Flen = aes128ecbBlobkSize * (args[0].GetType().Flen/aes128ecbBlobkSize + 1) // At most.
argTps := make([]types.EvalType, 0, len(args))
for range args {
argTps = append(argTps, types.ETString)
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, argTps...)
bf.tp.Flen = aes.BlockSize * (args[0].GetType().Flen/aes.BlockSize + 1) // At most.
types.SetBinChsClnFlag(bf.tp)
sig := &builtinAesEncryptSig{bf}
return sig, nil

blockMode, _ := ctx.GetSessionVars().GetSystemVar(variable.BlockEncryptionMode)
mode, exists := aesModes[strings.ToLower(blockMode)]
if !exists {
return nil, errors.Errorf("unsupported block encryption mode - %v", blockMode)
}
if mode.ivRequired {
if len(args) != 3 {
return nil, ErrIncorrectParameterCount.GenByArgs("aes_encrypt")
}
return &builtinAesEncryptIVSig{bf, mode}, nil
}
return &builtinAesEncryptSig{bf, mode}, nil
}

type builtinAesEncryptSig struct {
baseBuiltinFunc
*aesModeAttr
}

func (b *builtinAesEncryptSig) Clone() builtinFunc {
newSig := &builtinAesEncryptSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.aesModeAttr = b.aesModeAttr
return newSig
}

Expand All @@ -153,15 +267,72 @@ func (b *builtinAesEncryptSig) evalString(row types.Row) (string, bool, error) {
if isNull || err != nil {
return "", true, errors.Trace(err)
}
keyStr, isNull, err := b.args[1].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
}
if !b.ivRequired && len(b.args) == 3 {
// For modes that do not require init_vector, it is ignored and a warning is generated if it is specified.
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenByArgs("IV"))
}

key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize)
var cipherText []byte
switch b.modeName {
case "ecb":
cipherText, err = encrypt.AESEncryptWithECB([]byte(str), key)
default:
return "", true, errors.Errorf("unsupported block encryption mode - %v", b.modeName)
}
if err != nil {
return "", true, nil
}
return string(cipherText), false, nil
}

type builtinAesEncryptIVSig struct {
baseBuiltinFunc
*aesModeAttr
}

func (b *builtinAesEncryptIVSig) Clone() builtinFunc {
newSig := &builtinAesEncryptIVSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.aesModeAttr = b.aesModeAttr
return newSig
}

// evalString evals AES_ENCRYPT(str, key_str, iv).
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_aes-decrypt
func (b *builtinAesEncryptIVSig) evalString(row types.Row) (string, bool, error) {
// According to doc: If either function argument is NULL, the function returns NULL.
str, isNull, err := b.args[0].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
}
keyStr, isNull, err := b.args[1].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
}

// TODO: Support other modes.
key := encrypt.DeriveKeyMySQL([]byte(keyStr), aes128ecbBlobkSize)
cipherText, err := encrypt.AESEncryptWithECB([]byte(str), key)
iv, isNull, err := b.args[2].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
}
if len(iv) < aes.BlockSize {
return "", true, errIncorrectArgs.Gen("The initialization vector supplied to aes_encrypt is too short. Must be at least %d bytes long", aes.BlockSize)
}
// init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored)
iv = iv[0:aes.BlockSize]

key := encrypt.DeriveKeyMySQL([]byte(keyStr), b.keySize)
var cipherText []byte
switch b.modeName {
case "cbc":
cipherText, err = encrypt.AESEncryptWithCBC([]byte(str), key, []byte(iv))
default:
return "", true, errors.Errorf("unsupported block encryption mode - %v", b.modeName)
}
if err != nil {
return "", true, nil
}
Expand Down
Loading

0 comments on commit db1fb1f

Please sign in to comment.