Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: update validation and make it fail-fast #917

Merged
merged 1 commit into from
Jun 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 15 additions & 112 deletions pkg/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@ import (
"fmt"
"math/big"
"os"
"path/filepath"
"reflect"
"regexp"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -198,6 +196,11 @@ func (p *Provider) GetSecretsStoreObjectContent(ctx context.Context, attrib, sec
}
// remove whitespace from all fields in keyVaultObject
formatKeyVaultObject(&keyVaultObject)

if err = validate(keyVaultObject); err != nil {
return nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion)
}

keyVaultObjects = append(keyVaultObjects, keyVaultObject)
}

Expand All @@ -222,25 +225,6 @@ func (p *Provider) GetSecretsStoreObjectContent(ctx context.Context, attrib, sec
files := []types.SecretFile{}
for _, keyVaultObject := range keyVaultObjects {
klog.V(5).InfoS("fetching object from key vault", "objectName", keyVaultObject.ObjectName, "objectType", keyVaultObject.ObjectType, "keyvault", mc.keyvaultName, "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName})
if err := validateObjectFormat(keyVaultObject.ObjectFormat, keyVaultObject.ObjectType); err != nil {
return nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion)
}
if err := validateObjectEncoding(keyVaultObject.ObjectEncoding, keyVaultObject.ObjectType); err != nil {
return nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion)
}
fileName := keyVaultObject.ObjectName
if keyVaultObject.ObjectAlias != "" {
fileName = keyVaultObject.ObjectAlias
}
if err := validateFileName(fileName); err != nil {
return nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion)
}

filePermission, err := validateFilePermission(keyVaultObject.FilePermission, defaultFilePermission)
if err != nil {
return nil, err
}

// fetch the object from Key Vault
content, newObjectVersion, err := p.getKeyVaultObjectContent(ctx, kvClient, keyVaultObject, *vaultURL)
if err != nil {
Expand All @@ -255,16 +239,17 @@ func (p *Provider) GetSecretsStoreObjectContent(ctx context.Context, attrib, sec
// objectUID is a unique identifier in the format <object type>/<object name>
// This is the object id the user sees in the SecretProviderClassPodStatus
objectUID := getObjectUID(keyVaultObject.ObjectName, keyVaultObject.ObjectType)
file := types.SecretFile{
Path: keyVaultObject.GetFileName(),
Content: objectContent,
UID: objectUID,
Version: newObjectVersion,
}
// the validity of file permission is already checked in the validate function above
file.FileMode, _ = keyVaultObject.GetFilePermission(defaultFilePermission)

// these files will be returned to the CSI driver as part of gRPC response
files = append(files, types.SecretFile{
Path: fileName,
Content: objectContent,
FileMode: filePermission,
UID: objectUID,
Version: newObjectVersion,
})
klog.V(5).InfoS("added file to the gRPC response", "file", fileName, "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName})
files = append(files, file)
klog.V(5).InfoS("added file to the gRPC response", "file", file.Path, "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName})
}

return files, nil
Expand Down Expand Up @@ -531,23 +516,6 @@ func setAzureEnvironmentFilePath(envFileName string) error {
return os.Setenv(azure.EnvironmentFilepathName, envFileName)
}

// validateObjectFormat checks if the object format is valid and is supported
// for the given object type
func validateObjectFormat(objectFormat, objectType string) error {
if len(objectFormat) == 0 {
return nil
}
if !strings.EqualFold(objectFormat, types.ObjectFormatPEM) && !strings.EqualFold(objectFormat, types.ObjectFormatPFX) {
return fmt.Errorf("invalid objectFormat: %v, should be PEM or PFX", objectFormat)
}
// Azure Key Vault returns the base64 encoded binary content only for type secret
// for types cert/key, the content is always in pem format
if objectFormat == types.ObjectFormatPFX && objectType != types.VaultObjectTypeSecret {
return fmt.Errorf("PFX format only supported for objectType: secret")
}
return nil
}

// getObjectVersion parses the id to retrieve the version
// of object fetched
// example id format - https://kindkv.vault.azure.net/secrets/actual/1f304204f3624873aab40231241243eb
Expand All @@ -564,25 +532,6 @@ func getObjectUID(objectName, objectType string) string {
return fmt.Sprintf("%s/%s", objectType, objectName)
}

// validateObjectEncoding checks if the object encoding is valid and is supported
// for the given object type
func validateObjectEncoding(objectEncoding, objectType string) error {
if len(objectEncoding) == 0 {
return nil
}

// ObjectEncoding is supported only for secret types
if objectType != types.VaultObjectTypeSecret {
return fmt.Errorf("objectEncoding only supported for objectType: secret")
}

if !strings.EqualFold(objectEncoding, types.ObjectEncodingHex) && !strings.EqualFold(objectEncoding, types.ObjectEncodingBase64) && !strings.EqualFold(objectEncoding, types.ObjectEncodingUtf8) {
return fmt.Errorf("invalid objectEncoding: %v, should be hex, base64 or utf-8", objectEncoding)
}

return nil
}

// getContentBytes takes the given content string and returns the bytes to write to disk
// If an encoding is specified it will decode the string first
func getContentBytes(content, objectType, objectEncoding string) ([]byte, error) {
Expand Down Expand Up @@ -620,35 +569,6 @@ func formatKeyVaultObject(object *types.KeyVaultObject) {
}
}

// This validate will make sure fileName:
// 1. is not abs path
// 2. does not contain any '..' elements
// 3. does not start with '..'
// These checks have been implemented based on -
// [validateLocalDescendingPath] https://github.com/kubernetes/kubernetes/blob/master/pkg/apis/core/validation/validation.go#L1158-L1170
// [validatePathNoBacksteps] https://github.com/kubernetes/kubernetes/blob/master/pkg/apis/core/validation/validation.go#L1172-L1186
func validateFileName(fileName string) error {
if len(fileName) == 0 {
return fmt.Errorf("file name must not be empty")
}
// is not abs path
if filepath.IsAbs(fileName) {
return fmt.Errorf("file name must be a relative path")
}
// does not have any element which is ".."
parts := strings.Split(filepath.ToSlash(fileName), "/")
for _, item := range parts {
if item == ".." {
return fmt.Errorf("file name must not contain '..'")
}
}
// fallback logic if .. is missed in the previous check
if strings.Contains(fileName, "..") {
return fmt.Errorf("file name must not contain '..'")
}
return nil
}

type node struct {
cert *x509.Certificate
parent *node
Expand Down Expand Up @@ -744,20 +664,3 @@ func fetchCertChains(data []byte) ([]byte, error) {
}
return pemData, nil
}

// validateFilePermission checks if the given file permission is correct octal number and returns
// a. decimal equivalent of the default file permission (0644) if file permission is not provided Or
// b. decimal equivalent Or
// c. error if it's not valid
func validateFilePermission(filePermission string, defaultFilePermission os.FileMode) (int32, error) {
if filePermission == "" {
return int32(defaultFilePermission), nil
}

permission, err := strconv.ParseInt(filePermission, 8, 32)
if err != nil {
return 0, fmt.Errorf("file permission must be a valid octal number: %w", err)
}

return int32(permission), nil
}
186 changes: 0 additions & 186 deletions pkg/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,104 +248,6 @@ func TestParseAzureEnvironmentAzureStackCloud(t *testing.T) {
}
}

func TestValidateObjectFormat(t *testing.T) {
cases := []struct {
desc string
objectFormat string
objectType string
expectedErr error
}{
{
desc: "no object format specified",
objectFormat: "",
objectType: "cert",
expectedErr: nil,
},
{
desc: "object format not valid",
objectFormat: "pkcs",
objectType: "secret",
expectedErr: fmt.Errorf("invalid objectFormat: pkcs, should be PEM or PFX"),
},
{
desc: "object format PFX, but object type not secret",
objectFormat: "pfx",
objectType: "cert",
expectedErr: fmt.Errorf("PFX format only supported for objectType: secret"),
},
{
desc: "object format PFX case insensitive check",
objectFormat: "PFX",
objectType: "secret",
expectedErr: nil,
},
{
desc: "valid object format and type",
objectFormat: "pfx",
objectType: "secret",
expectedErr: nil,
},
}

for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := validateObjectFormat(tc.objectFormat, tc.objectType)
if tc.expectedErr != nil && err.Error() != tc.expectedErr.Error() || tc.expectedErr == nil && err != nil {
t.Fatalf("expected err: %+v, got: %+v", tc.expectedErr, err)
}
})
}
}

func TestValidateObjectEncoding(t *testing.T) {
cases := []struct {
desc string
objectEncoding string
objectType string
expectedErr error
}{
{
desc: "No encoding specified",
objectEncoding: "",
objectType: "cert",
expectedErr: nil,
},
{
desc: "Invalid encoding specified",
objectEncoding: "utf-16",
objectType: "secret",
expectedErr: fmt.Errorf("invalid objectEncoding: utf-16, should be hex, base64 or utf-8"),
},
{
desc: "Object Encoding Base64, but objectType is not secret",
objectEncoding: "base64",
objectType: "cert",
expectedErr: fmt.Errorf("objectEncoding only supported for objectType: secret"),
},
{
desc: "Object Encoding case-insensitive check",
objectEncoding: "BasE64",
objectType: "secret",
expectedErr: nil,
},
{
desc: "Valid ObjectEncoding and Type",
objectEncoding: "base64",
objectType: "secret",
expectedErr: nil,
},
}

for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := validateObjectEncoding(tc.objectEncoding, tc.objectType)
if tc.expectedErr != nil && err.Error() != tc.expectedErr.Error() || tc.expectedErr == nil && err != nil {
t.Fatalf("expected err: %+v, got: %+v", tc.expectedErr, err)
}
})
}
}

func TestGetContentBytes(t *testing.T) {
cases := []struct {
desc string
Expand Down Expand Up @@ -480,49 +382,6 @@ func TestFormatKeyVaultObject(t *testing.T) {
}
}

func TestValidateFilePath(t *testing.T) {
cases := []struct {
desc string
fileName string
expectedErr error
}{
{
desc: "file name is absolute path",
fileName: "/secret1",
expectedErr: fmt.Errorf("file name must be a relative path"),
},
{
desc: "file name contains '..'",
fileName: "secret1/..",
expectedErr: fmt.Errorf("file name must not contain '..'"),
},
{
desc: "file name starts with '..'",
fileName: "../secret1",
expectedErr: fmt.Errorf("file name must not contain '..'"),
},
{
desc: "file name is empty",
fileName: "",
expectedErr: fmt.Errorf("file name must not be empty"),
},
{
desc: "valid file name",
fileName: "secret1",
expectedErr: nil,
},
}

for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := validateFileName(tc.fileName)
if tc.expectedErr != nil && err.Error() != tc.expectedErr.Error() || tc.expectedErr == nil && err != nil {
t.Fatalf("expected err: %+v, got: %+v", tc.expectedErr, err)
}
})
}
}

func TestFetchCertChain(t *testing.T) {
rootCACert := `
-----BEGIN CERTIFICATE-----
Expand Down Expand Up @@ -1050,48 +909,3 @@ func TestGetObjectVersion(t *testing.T) {
actual := getObjectVersion(id)
assert.Equal(t, expectedVersion, actual)
}

func TestValidateFilePermisssion(t *testing.T) {
cases := []struct {
desc string
filePermission string
defaultFilePermission os.FileMode
isErrorExpected bool
}{
{
desc: "valid file permission",
filePermission: "0600",
defaultFilePermission: os.FileMode(0644),
isErrorExpected: false,
},
{
desc: "empty file permission",
filePermission: "",
defaultFilePermission: os.FileMode(0644),
isErrorExpected: false,
},
{
desc: "invalid file permission",
filePermission: "0900",
defaultFilePermission: os.FileMode(0644),
isErrorExpected: true,
},
{
desc: "invalid octal number",
filePermission: "900",
defaultFilePermission: os.FileMode(0644),
isErrorExpected: true,
},
}

for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
_, err := validateFilePermission(tc.filePermission, tc.defaultFilePermission)
if tc.isErrorExpected {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
Loading