Skip to content

Commit

Permalink
Allow users to add aws secrets
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeffwan committed Apr 12, 2019
1 parent 39378a6 commit 8d109fd
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 0 deletions.
59 changes: 59 additions & 0 deletions sdk/python/kfp/aws.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2019 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

def use_aws_secret(secret_name='aws-secret', aws_access_key_id_name='AWS_ACCESS_KEY_ID', aws_secret_access_key_name='AWS_SECRET_ACCESS_KEY'):
"""An operator that configures the container to use AWS credentials.
AWS doesn't create secret along with kubeflow deployment and it requires users
to manually create credential secret with proper permissions.
---
apiVersion: v1
kind: Secret
metadata:
name: aws-secret
type: Opaque
data:
AWS_ACCESS_KEY_ID: BASE64_YOUR_AWS_ACCESS_KEY_ID
AWS_SECRET_ACCESS_KEY: BASE64_YOUR_AWS_SECRET_ACCESS_KEY
"""

def _use_aws_secret(task):
from kubernetes import client as k8s_client
return (
task
.add_env_variable(
k8s_client.V1EnvVar(
name='AWS_ACCESS_KEY_ID',
value_from=k8s_client.V1EnvVarSource(
secret_key_ref=k8s_client.V1SecretKeySelector(
name=secret_name,
key=aws_access_key_id_name
)
)
)
)
.add_env_variable(
k8s_client.V1EnvVar(
name='AWS_SECRET_ACCESS_KEY',
value_from=k8s_client.V1EnvVarSource(
secret_key_ref=k8s_client.V1SecretKeySelector(
name=secret_name,
key=aws_secret_access_key_name
)
)
)
)
)

return _use_aws_secret
42 changes: 42 additions & 0 deletions sdk/python/tests/dsl/aws_extensions_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2019 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from kfp.dsl import Pipeline, ContainerOp
from kfp.aws import use_aws_secret
import unittest
import inspect

class AwsExtensionTests(unittest.TestCase):
def test_default_aws_secret_name(self):
spec = inspect.getfullargspec(use_aws_secret)
assert len(spec.defaults) == 3
assert spec.defaults[0] == 'aws-secret'
assert spec.defaults[1] == 'AWS_ACCESS_KEY_ID'
assert spec.defaults[2] == 'AWS_SECRET_ACCESS_KEY'

def test_use_aws_secret(self):
with Pipeline('somename') as p:
op1 = ContainerOp(name='op1', image='image')
op1 = op1.apply(use_aws_secret('myaws-secret', 'key_id', 'access_key'))
assert len(op1.env_variables) == 2

index = 0
for expected in ['key_id', 'access_key']:
assert op1.env_variables[index].name == expected
assert op1.env_variables[index].value_from.secret_key_ref.name == 'myaws-secret'
assert op1.env_variables[index].value_from.secret_key_ref.key == expected
index += 1

if __name__ == '__main__':
unittest.main()

0 comments on commit 8d109fd

Please sign in to comment.