Skip to content

Commit

Permalink
refactor: use secret file for object versions (#761)
Browse files Browse the repository at this point in the history
* refactor: use secret file for object versions

Signed-off-by: Anish Ramasekar <anish.ramasekar@gmail.com>

* refactor: update function name to GetSecretsStoreObjectContent

Signed-off-by: Anish Ramasekar <anish.ramasekar@gmail.com>
  • Loading branch information
aramase authored Jan 20, 2022
1 parent c06bc12 commit 9492015
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 37 deletions.
63 changes: 33 additions & 30 deletions pkg/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ type SecretFile struct {
Content []byte
Path string
FileMode int32
UID string
Version string
}

// StringArray ...
Expand Down Expand Up @@ -183,8 +185,9 @@ func (mc *mountConfig) GetServicePrincipalToken(resource string) (*adal.ServiceP
return mc.authConfig.GetServicePrincipalToken(mc.podName, mc.podNamespace, resource, mc.azureCloudEnvironment.ActiveDirectoryEndpoint, mc.tenantID, podIdentityNMIPort)
}

// MountSecretsStoreObjectContent mounts content of the secrets store object to target path
func (p *Provider) MountSecretsStoreObjectContent(ctx context.Context, attrib map[string]string, secrets map[string]string, targetPath string, defaultFilePermission os.FileMode) ([]SecretFile, map[string]string, error) {
// GetSecretsStoreObjectContent gets the objects (secret, key, certificate) from keyvault and returns the content
// to the CSI driver. The driver will write the content to the file system.
func (p *Provider) GetSecretsStoreObjectContent(ctx context.Context, attrib, secrets map[string]string, targetPath string, defaultFilePermission os.FileMode) ([]SecretFile, error) {
keyvaultName := strings.TrimSpace(attrib["keyvaultName"])
cloudName := strings.TrimSpace(attrib["cloudName"])
usePodIdentityStr := strings.TrimSpace(attrib["usePodIdentity"])
Expand All @@ -196,38 +199,38 @@ func (p *Provider) MountSecretsStoreObjectContent(ctx context.Context, attrib ma
podNamespace := strings.TrimSpace(attrib["csi.storage.k8s.io/pod.namespace"])

if keyvaultName == "" {
return nil, nil, fmt.Errorf("keyvaultName is not set")
return nil, fmt.Errorf("keyvaultName is not set")
}
if tenantID == "" {
return nil, nil, fmt.Errorf("tenantId is not set")
return nil, fmt.Errorf("tenantId is not set")
}
if len(usePodIdentityStr) == 0 {
usePodIdentityStr = "false"
}
usePodIdentity, err := strconv.ParseBool(usePodIdentityStr)
if err != nil {
return nil, nil, fmt.Errorf("failed to parse usePodIdentity flag, error: %w", err)
return nil, fmt.Errorf("failed to parse usePodIdentity flag, error: %w", err)
}
if len(useVMManagedIdentityStr) == 0 {
useVMManagedIdentityStr = "false"
}
useVMManagedIdentity, err := strconv.ParseBool(useVMManagedIdentityStr)
if err != nil {
return nil, nil, fmt.Errorf("failed to parse useVMManagedIdentity flag, error: %w", err)
return nil, fmt.Errorf("failed to parse useVMManagedIdentity flag, error: %w", err)
}

err = setAzureEnvironmentFilePath(cloudEnvFileName)
if err != nil {
return nil, nil, fmt.Errorf("failed to set AZURE_ENVIRONMENT_FILEPATH env to %s, error %w", cloudEnvFileName, err)
return nil, fmt.Errorf("failed to set AZURE_ENVIRONMENT_FILEPATH env to %s, error %w", cloudEnvFileName, err)
}
azureCloudEnv, err := ParseAzureEnvironment(cloudName)
if err != nil {
return nil, nil, fmt.Errorf("cloudName %s is not valid, error: %w", cloudName, err)
return nil, fmt.Errorf("cloudName %s is not valid, error: %w", cloudName, err)
}

authConfig, err := auth.NewConfig(usePodIdentity, useVMManagedIdentity, userAssignedIdentityID, secrets)
if err != nil {
return nil, nil, fmt.Errorf("failed to create auth config, error: %w", err)
return nil, fmt.Errorf("failed to create auth config, error: %w", err)
}

mc := &mountConfig{
Expand All @@ -241,22 +244,22 @@ func (p *Provider) MountSecretsStoreObjectContent(ctx context.Context, attrib ma

objectsStrings := attrib["objects"]
if objectsStrings == "" {
return nil, nil, fmt.Errorf("objects is not set")
return nil, fmt.Errorf("objects is not set")
}
klog.V(2).InfoS("objects string defined in secret provider class", "objects", objectsStrings, "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName})

var objects StringArray
err = yaml.Unmarshal([]byte(objectsStrings), &objects)
if err != nil {
return nil, nil, fmt.Errorf("failed to yaml unmarshal objects, error: %w", err)
return nil, fmt.Errorf("failed to yaml unmarshal objects, error: %w", err)
}
klog.V(2).InfoS("unmarshaled objects yaml array", "objectsArray", objects.Array, "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName})
keyVaultObjects := []KeyVaultObject{}
for i, object := range objects.Array {
var keyVaultObject KeyVaultObject
err = yaml.Unmarshal([]byte(object), &keyVaultObject)
if err != nil {
return nil, nil, fmt.Errorf("unmarshal failed for keyVaultObjects at index %d, error: %w", i, err)
return nil, fmt.Errorf("unmarshal failed for keyVaultObjects at index %d, error: %w", i, err)
}
// remove whitespace from all fields in keyVaultObject
formatKeyVaultObject(&keyVaultObject)
Expand All @@ -266,74 +269,74 @@ func (p *Provider) MountSecretsStoreObjectContent(ctx context.Context, attrib ma
klog.V(5).InfoS("unmarshaled key vault objects", "keyVaultObjects", keyVaultObjects, "count", len(keyVaultObjects), "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName})

if len(keyVaultObjects) == 0 {
return nil, make(map[string]string), nil
return nil, nil
}

vaultURL, err := mc.getVaultURL()
if err != nil {
return nil, nil, errors.Wrap(err, "failed to get vault")
return nil, errors.Wrap(err, "failed to get vault")
}
klog.V(2).InfoS("vault url", "vaultName", mc.keyvaultName, "vaultURL", *vaultURL, "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName})

// the keyvault name is per SPC and we don't need to recreate the client for every single keyvault object defined
kvClient, err := mc.initializeKvClient()
if err != nil {
return nil, nil, errors.Wrap(err, "failed to get keyvault client")
return nil, errors.Wrap(err, "failed to get keyvault client")
}

objectVersionMap := make(map[string]string)
files := []SecretFile{}
for _, keyVaultObject := range keyVaultObjects {
klog.V(5).InfoS("fetching object from key vault", "objectName", keyVaultObject.ObjectName, "objectType", keyVaultObject.ObjectType, "keyvault", mc.keyvaultName, "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName})
if err := validateObjectFormat(keyVaultObject.ObjectFormat, keyVaultObject.ObjectType); err != nil {
return nil, nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion)
return nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion)
}
if err := validateObjectEncoding(keyVaultObject.ObjectEncoding, keyVaultObject.ObjectType); err != nil {
return nil, nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion)
return nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion)
}
fileName := keyVaultObject.ObjectName
if keyVaultObject.ObjectAlias != "" {
fileName = keyVaultObject.ObjectAlias
}
if err := validateFileName(fileName); err != nil {
return nil, nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion)
return nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion)
}

filePermission, err := validateFilePermission(keyVaultObject.FilePermission, defaultFilePermission)
if err != nil {
return nil, nil, err
return nil, err
}

// fetch the object from Key Vault
content, newObjectVersion, err := p.GetKeyVaultObjectContent(ctx, kvClient, keyVaultObject, *vaultURL)
content, newObjectVersion, err := p.getKeyVaultObjectContent(ctx, kvClient, keyVaultObject, *vaultURL)
if err != nil {
return nil, nil, err
return nil, err
}

// objectUID is a unique identifier in the format <object type>/<object name>
// This is the object id the user sees in the SecretProviderClassPodStatus
objectUID := getObjectUID(keyVaultObject.ObjectName, keyVaultObject.ObjectType)
objectVersionMap[objectUID] = newObjectVersion

objectContent, err := getContentBytes(content, keyVaultObject.ObjectType, keyVaultObject.ObjectEncoding)
if err != nil {
return nil, nil, err
return nil, err
}

// objectUID is a unique identifier in the format <object type>/<object name>
// This is the object id the user sees in the SecretProviderClassPodStatus
objectUID := getObjectUID(keyVaultObject.ObjectName, keyVaultObject.ObjectType)

// these files will be returned to the CSI driver as part of gRPC response
files = append(files, SecretFile{
Path: fileName,
Content: objectContent,
FileMode: filePermission,
UID: objectUID,
Version: newObjectVersion,
})
klog.V(5).InfoS("added file to the gRPC response", "file", fileName, "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName})
}

return files, objectVersionMap, nil
return files, nil
}

// GetKeyVaultObjectContent get content of the keyvault object
func (p *Provider) GetKeyVaultObjectContent(ctx context.Context, kvClient *kv.BaseClient, kvObject KeyVaultObject, vaultURL string) (content, version string, err error) {
func (p *Provider) getKeyVaultObjectContent(ctx context.Context, kvClient *kv.BaseClient, kvObject KeyVaultObject, vaultURL string) (content, version string, err error) {
start := time.Now()
defer func() {
var errMsg string
Expand Down
4 changes: 2 additions & 2 deletions pkg/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ func TestInitializeKVClient(t *testing.T) {
}
}

func TestMountSecretsStoreObjectContent(t *testing.T) {
func TestGetSecretsStoreObjectContent(t *testing.T) {
cases := []struct {
desc string
parameters map[string]string
Expand Down Expand Up @@ -829,7 +829,7 @@ func TestMountSecretsStoreObjectContent(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "ut")
assert.NoError(t, err)

_, _, err = p.MountSecretsStoreObjectContent(context.TODO(), tc.parameters, tc.secrets, tmpDir, 0420)
_, err = p.GetSecretsStoreObjectContent(context.TODO(), tc.parameters, tc.secrets, tmpDir, 0420)
if tc.expectedErr {
assert.NotNil(t, err)
} else {
Expand Down
11 changes: 6 additions & 5 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,12 @@ func (s *CSIDriverProviderServer) Mount(ctx context.Context, req *v1alpha1.Mount
return &v1alpha1.MountResponse{}, fmt.Errorf("failed to unmarshal file permission, error: %w", err)
}

files, objectVersions, err := s.Provider.MountSecretsStoreObjectContent(ctx, attrib, secret, req.GetTargetPath(), defaultFilePermission)
files, err := s.Provider.GetSecretsStoreObjectContent(ctx, attrib, secret, req.GetTargetPath(), defaultFilePermission)
if err != nil {
klog.ErrorS(err, "failed to process mount request")
return &v1alpha1.MountResponse{}, fmt.Errorf("failed to mount objects, error: %w", err)
}
ov := []*v1alpha1.ObjectVersion{}
for k, v := range objectVersions {
ov = append(ov, &v1alpha1.ObjectVersion{Id: k, Version: v})
}

f := []*v1alpha1.File{}
// CSI driver v0.0.21+ will write to the filesystem if the files are in the response.
// No files in the response translates to "not implemented" in the CSI driver.
Expand All @@ -72,6 +68,11 @@ func (s *CSIDriverProviderServer) Mount(ctx context.Context, req *v1alpha1.Mount
Contents: file.Content,
Mode: file.FileMode,
})

ov = append(ov, &v1alpha1.ObjectVersion{
Id: file.UID,
Version: file.Version,
})
}

return &v1alpha1.MountResponse{
Expand Down

0 comments on commit 9492015

Please sign in to comment.