Skip to content

Commit

Permalink
SageMaker Deployment: Allow users to specify VPC configurations for m…
Browse files Browse the repository at this point in the history
…odel creation (mlflow#304)
  • Loading branch information
dbczumar authored and aarondav committed Aug 18, 2018
1 parent 888b0d3 commit 384001f
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 16 deletions.
64 changes: 50 additions & 14 deletions mlflow/sagemaker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def push_image_to_ecr(image=DEFAULT_IMAGE_NAME):
def deploy(app_name, model_path, execution_role_arn=None, bucket=None, run_id=None,
image_url=None, region_name="us-west-2", mode=DEPLOYMENT_MODE_CREATE, archive=False,
instance_type=DEFAULT_SAGEMAKER_INSTANCE_TYPE,
instance_count=DEFAULT_SAGEMAKER_INSTANCE_COUNT):
instance_count=DEFAULT_SAGEMAKER_INSTANCE_COUNT, vpc_config=None):
"""
Deploy model on SageMaker.
Currently active AWS account needs to have correct permissions set up.
Expand Down Expand Up @@ -211,7 +211,25 @@ def deploy(app_name, model_path, execution_role_arn=None, bucket=None, run_id=No
of supported instance types, see
https://aws.amazon.com/sagemaker/pricing/instance-types/.
:param instance_count: The number of SageMaker ML instances on which to deploy the model.
:param vpc_config: A dictionary specifying the VPC configuration to use when creating the
new SageMaker model associated with this application. The acceptable values
for this parameter are identical to those of the `VpcConfig` parameter in the
SageMaker boto3 client (https://boto3.readthedocs.io/en/latest/reference/
services/sagemaker.html#SageMaker.Client.create_model). For more information,
see https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html.
Example:
>>> import mlflow.sagemaker as mfs
>>> vpc_config = {
... 'SecurityGroupIds': [
... 'sg-123456abc',
... ],
... 'Subnets': [
... 'subnet-123456abc',
... ]
... }
>>> mfs.deploy(..., vpc_config=vpc_config)
"""
if mode not in DEPLOYMENT_MODES:
raise ValueError("`mode` must be one of: {mds}".format(
Expand Down Expand Up @@ -244,7 +262,8 @@ def deploy(app_name, model_path, execution_role_arn=None, bucket=None, run_id=No
mode=mode,
archive=archive,
instance_type=instance_type,
instance_count=instance_count)
instance_count=instance_count,
vpc_config=vpc_config)


def delete(app_name, region_name="us-west-2", archive=False):
Expand Down Expand Up @@ -418,7 +437,7 @@ def _upload_s3(local_model_path, bucket, prefix):


def _deploy(role, image_url, app_name, model_s3_path, run_id, region_name, mode, archive,
instance_type, instance_count):
instance_type, instance_count, vpc_config):
"""
Deploy model on sagemaker.
:param role: SageMaker execution ARN role
Expand All @@ -432,6 +451,8 @@ def _deploy(role, image_url, app_name, model_s3_path, run_id, region_name, mode,
will be preserved. If False, these resources will be deleted.
:param instance_type: The type of SageMaker ML instance on which to deploy the model.
:param instance_count: The number of SageMaker ML instances on which to deploy the model.
:param vpc_config: A dictionary specifying the VPC configuration to use when creating the
new SageMaker model associated with this application.
"""
sage_client = boto3.client('sagemaker', region_name=region_name)
s3_client = boto3.client('s3', region_name=region_name)
Expand Down Expand Up @@ -464,6 +485,7 @@ def _deploy(role, image_url, app_name, model_s3_path, run_id, region_name, mode,
run_id=run_id,
instance_type=instance_type,
instance_count=instance_count,
vpc_config=vpc_config,
mode=mode,
archive=archive,
role=role,
Expand All @@ -476,6 +498,7 @@ def _deploy(role, image_url, app_name, model_s3_path, run_id, region_name, mode,
run_id=run_id,
instance_type=instance_type,
instance_count=instance_count,
vpc_config=vpc_config,
role=role,
sage_client=sage_client)

Expand Down Expand Up @@ -509,13 +532,15 @@ def _get_sagemaker_config_name(endpoint_name):


def _create_sagemaker_endpoint(endpoint_name, image_url, model_s3_path, run_id, instance_type,
instance_count, role, sage_client):
vpc_config, instance_count, role, sage_client):
"""
:param image_url: URL of the ECR-hosted docker image the model is being deployed into.
:param model_s3_path: S3 path where we stored the model artifacts.
:param run_id: Run ID that generated this model.
:param instance_type: The type of SageMaker ML instance on which to deploy the model.
:param instance_count: The number of SageMaker ML instances on which to deploy the model.
:param vpc_config: A dictionary specifying the VPC configuration to use when creating the
new SageMaker model associated with this SageMaker endpoint.
:param role: SageMaker execution ARN role
:param sage_client: A boto3 client for SageMaker
"""
Expand All @@ -525,6 +550,7 @@ def _create_sagemaker_endpoint(endpoint_name, image_url, model_s3_path, run_id,
model_name = _get_sagemaker_model_name(endpoint_name)
model_response = _create_sagemaker_model(model_name=model_name,
model_s3_path=model_s3_path,
vpc_config=vpc_config,
run_id=run_id,
image_url=image_url,
execution_role=role,
Expand Down Expand Up @@ -561,13 +587,16 @@ def _create_sagemaker_endpoint(endpoint_name, image_url, model_s3_path, run_id,


def _update_sagemaker_endpoint(endpoint_name, image_url, model_s3_path, run_id, instance_type,
instance_count, mode, archive, role, sage_client, s3_client):
instance_count, vpc_config, mode, archive, role, sage_client,
s3_client):
"""
:param image_url: URL of the ECR-hosted Docker image the model is being deployed into
:param model_s3_path: S3 path where we stored the model artifacts
:param run_id: Run ID that generated this model
:param instance_type: The type of SageMaker ML instance on which to deploy the model.
:param instance_count: The number of SageMaker ML instances on which to deploy the model.
:param vpc_config: A dictionary specifying the VPC configuration to use when creating the
new SageMaker model associated with this SageMaker endpoint.
:param mode: either mlflow.sagemaker.DEPLOYMENT_MODE_ADD or
mlflow.sagemaker.DEPLOYMENT_MODE_REPLACE.
:param archive: If True, any pre-existing SageMaker application resources that become inactive
Expand Down Expand Up @@ -596,6 +625,7 @@ def _update_sagemaker_endpoint(endpoint_name, image_url, model_s3_path, run_id,
new_model_name = _get_sagemaker_model_name(endpoint_name)
new_model_response = _create_sagemaker_model(model_name=new_model_name,
model_s3_path=model_s3_path,
vpc_config=vpc_config,
run_id=run_id,
image_url=image_url,
execution_role=role,
Expand Down Expand Up @@ -656,28 +686,34 @@ def _update_sagemaker_endpoint(endpoint_name, image_url, model_s3_path, run_id,
carn=deployed_config_arn))


def _create_sagemaker_model(model_name, model_s3_path, run_id, image_url, execution_role,
sage_client):
def _create_sagemaker_model(model_name, model_s3_path, vpc_config, run_id, image_url,
execution_role, sage_client):
"""
:param model_s3_path: S3 path where the model artifacts are stored
:param vpc_config: A dictionary specifying the VPC configuration to use when creating the
new SageMaker model associated with this SageMaker endpoint.
:param run_id: Run ID that generated this model
:param image_url: URL of the ECR-hosted Docker image that will serve as the
model's container
:param execution_role: The ARN of the role that SageMaker will assume when creating the model
:param sage_client: A boto3 client for SageMaker
:return: AWS response containing metadata associated with the new model
"""
model_response = sage_client.create_model(
ModelName=model_name,
PrimaryContainer={
create_model_args = {
"ModelName": model_name,
"PrimaryContainer": {
'ContainerHostname': 'mfs-%s' % model_name,
'Image': image_url,
'ModelDataUrl': model_s3_path,
'Environment': {},
},
ExecutionRoleArn=execution_role,
Tags=[{'Key': 'run_id', 'Value': str(run_id)}, ],
)
"ExecutionRoleArn": execution_role,
"Tags": [{'Key': 'run_id', 'Value': str(run_id)}],
}
if vpc_config is not None:
create_model_args["VpcConfig"] = vpc_config

model_response = sage_client.create_model(**create_model_args)
return model_response


Expand Down
14 changes: 12 additions & 2 deletions mlflow/sagemaker/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import print_function

import os
import json

import click

Expand Down Expand Up @@ -37,17 +38,26 @@ def commands():
" https://aws.amazon.com/sagemaker/pricing/instance-types/.")
@click.option("--instance-count", "-c", default=mlflow.sagemaker.DEFAULT_SAGEMAKER_INSTANCE_COUNT,
help="The number of SageMaker ML instances on which to deploy the model")
@click.option("--vpc-config", "-v",
help="Path to a file containing a JSON-formatted VPC configuration. This"
" configuration will be used when creating the new SageMaker model associated"
" with this application. For more information, see"
" https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html")
def deploy(app_name, model_path, execution_role_arn, bucket, run_id, image_url, region_name, mode,
archive, instance_type, instance_count):
archive, instance_type, instance_count, vpc_config):
"""
Deploy model on Sagemaker as a REST API endpoint. Current active AWS account needs to have
correct permissions setup.
"""
if vpc_config is not None:
with open(vpc_config, "r") as f:
vpc_config = json.load(f)

mlflow.sagemaker.deploy(app_name=app_name, model_path=model_path,
execution_role_arn=execution_role_arn, bucket=bucket, run_id=run_id,
image_url=image_url, region_name=region_name, mode=mode,
archive=archive, instance_type=instance_type,
instance_count=instance_count)
instance_count=instance_count, vpc_config=vpc_config)


@commands.command("delete")
Expand Down

0 comments on commit 384001f

Please sign in to comment.