Skip to content

Commit

Permalink
feat(start-api) Support stage name and stage variables (aws#1203)
Browse files Browse the repository at this point in the history
  • Loading branch information
viksrivat authored and jfuss committed Jun 24, 2019
1 parent 81e7e13 commit 513fe03
Show file tree
Hide file tree
Showing 11 changed files with 509 additions and 136 deletions.
14 changes: 7 additions & 7 deletions samcli/commands/local/lib/local_api_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ def _make_routing_list(api_provider):
routes = []
for api in api_provider.get_all():
route = Route(methods=[api.method], function_name=api.function_name, path=api.path,
binary_types=api.binary_media_types)
binary_types=api.binary_media_types, stage_name=api.stage_name,
stage_variables=api.stage_variables)
routes.append(route)

return routes

@staticmethod
Expand Down Expand Up @@ -139,11 +139,11 @@ def _print_routes(api_provider, host, port):
for _, config in grouped_api_configs.items():
methods_str = "[{}]".format(', '.join(config["methods"]))
output = "Mounting {} at http://{}:{}{} {}".format(
config["function_name"],
host,
port,
config["path"],
methods_str)
config["function_name"],
host,
port,
config["path"],
methods_str)
print_lines.append(output)

LOG.info(output)
Expand Down
12 changes: 9 additions & 3 deletions samcli/commands/local/lib/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,16 @@ def get_all(self):
"cors",

# List(Str). List of the binary media types the API
"binary_media_types"
"binary_media_types",
# The Api stage name
"stage_name",
# The variables for that stage
"stage_variables"
])
_ApiTuple.__new__.__defaults__ = (None, # Cors is optional and defaults to None
[] # binary_media_types is optional and defaults to empty
_ApiTuple.__new__.__defaults__ = (None, # Cors is optional and defaults to None
[], # binary_media_types is optional and defaults to empty,
None, # Stage name is optional with default None
None # Stage variables is optional with default None
)


Expand Down
67 changes: 63 additions & 4 deletions samcli/commands/local/lib/sam_api_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@


class SamApiProvider(ApiProvider):

_IMPLICIT_API_RESOURCE_ID = "ServerlessRestApi"
_SERVERLESS_FUNCTION = "AWS::Serverless::Function"
_SERVERLESS_API = "AWS::Serverless::Api"
Expand Down Expand Up @@ -127,6 +126,8 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector):
body = properties.get("DefinitionBody")
uri = properties.get("DefinitionUri")
binary_media = properties.get("BinaryMediaTypes", [])
stage_name = properties.get("StageName")
stage_variables = properties.get("Variables")

if not body and not uri:
# Swagger is not found anywhere.
Expand All @@ -146,6 +147,9 @@ def _extract_from_serverless_api(self, logical_id, api_resource, collector):
collector.add_binary_media_types(logical_id, parser.get_binary_media_types()) # Binary media from swagger
collector.add_binary_media_types(logical_id, binary_media) # Binary media specified on resource in template

collector.add_stage_name(logical_id, stage_name)
collector.add_stage_variables(logical_id, stage_variables)

@staticmethod
def _merge_apis(collector):
"""
Expand Down Expand Up @@ -324,7 +328,7 @@ class ApiCollector(object):
# This is intentional because it allows us to easily extend this class to support future properties on the API.
# We will store properties of Implicit APIs also in this format which converges the handling of implicit & explicit
# APIs.
Properties = namedtuple("Properties", ["apis", "binary_media_types", "cors"])
Properties = namedtuple("Properties", ["apis", "binary_media_types", "cors", "stage_name", "stage_variables"])

def __init__(self):
# API properties stored per resource. Key is the LogicalId of the AWS::Serverless::Api resource and
Expand Down Expand Up @@ -387,6 +391,40 @@ def add_binary_media_types(self, logical_id, binary_media_types):
else:
LOG.debug("Unsupported data type of binary media type value of resource '%s'", logical_id)

def add_stage_name(self, logical_id, stage_name):
"""
Stores the stage name for the API with the given local ID
Parameters
----------
logical_id : str
LogicalId of the AWS::Serverless::Api resource
stage_name : str
The stage_name string
"""
properties = self._get_properties(logical_id)
properties = properties._replace(stage_name=stage_name)
self._set_properties(logical_id, properties)

def add_stage_variables(self, logical_id, stage_variables):
"""
Stores the stage variables for the API with the given local ID
Parameters
----------
logical_id : str
LogicalId of the AWS::Serverless::Api resource
stage_variables : dict
A dictionary containing stage variables.
"""
properties = self._get_properties(logical_id)
properties = properties._replace(stage_variables=stage_variables)
self._set_properties(logical_id, properties)

def _get_apis_with_config(self, logical_id):
"""
Returns the list of APIs in this resource along with other extra configuration such as binary media types,
Expand All @@ -410,12 +448,16 @@ def _get_apis_with_config(self, logical_id):
# These configs need to be applied to each API
binary_media = sorted(list(properties.binary_media_types)) # Also sort the list to keep the ordering stable
cors = properties.cors
stage_name = properties.stage_name
stage_variables = properties.stage_variables

result = []
for api in properties.apis:
# Create a copy of the API with updated configuration
updated_api = api._replace(binary_media_types=binary_media,
cors=cors)
cors=cors,
stage_name=stage_name,
stage_variables=stage_variables)
result.append(updated_api)

return result
Expand All @@ -440,10 +482,27 @@ def _get_properties(self, logical_id):
self.by_resource[logical_id] = self.Properties(apis=[],
# Use a set() to be able to easily de-dupe
binary_media_types=set(),
cors=None)
cors=None,
stage_name=None,
stage_variables=None)

return self.by_resource[logical_id]

def _set_properties(self, logical_id, properties):
"""
Sets the properties of resource with given logical ID. If a resource is not found, it does nothing
Parameters
----------
logical_id : str
Logical ID of the resource
properties : samcli.commands.local.lib.sam_api_provider.ApiCollector.Properties
Properties object for this resource.
"""

if logical_id in self.by_resource:
self.by_resource[logical_id] = properties

@staticmethod
def _normalize_binary_media_type(value):
"""
Expand Down
1 change: 0 additions & 1 deletion samcli/commands/local/lib/swagger/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@


class SwaggerParser(object):

_INTEGRATION_KEY = "x-amazon-apigateway-integration"
_ANY_METHOD_EXTENSION_KEY = "x-amazon-apigateway-any-method"
_BINARY_MEDIA_TYPES_EXTENSION_KEY = "x-amazon-apigateway-binary-media-types" # pylint: disable=C0103
Expand Down
16 changes: 10 additions & 6 deletions samcli/local/apigw/local_apigw_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

class Route(object):

def __init__(self, methods, function_name, path, binary_types=None):
def __init__(self, methods, function_name, path, binary_types=None, stage_name=None, stage_variables=None):
"""
Creates an ApiGatewayRoute
Expand All @@ -31,10 +31,11 @@ def __init__(self, methods, function_name, path, binary_types=None):
self.function_name = function_name
self.path = path
self.binary_types = binary_types or []
self.stage_name = stage_name
self.stage_variables = stage_variables


class LocalApigwService(BaseLocalService):

_DEFAULT_PORT = 3000
_DEFAULT_HOST = '127.0.0.1'

Expand Down Expand Up @@ -143,7 +144,8 @@ def _request_handler(self, **kwargs):
route = self._get_current_route(request)

try:
event = self._construct_event(request, self.port, route.binary_types)
event = self._construct_event(request, self.port, route.binary_types, route.stage_name,
route.stage_variables)
except UnicodeDecodeError:
return ServiceErrorResponses.lambda_failure_response()

Expand Down Expand Up @@ -313,13 +315,14 @@ def _merge_response_headers(headers, multi_headers):
return processed_headers

@staticmethod
def _construct_event(flask_request, port, binary_types):
def _construct_event(flask_request, port, binary_types, stage_name=None, stage_variables=None):
"""
Helper method that constructs the Event to be passed to Lambda
:param request flask_request: Flask Request
:return: String representing the event
"""
# pylint: disable-msg=too-many-locals

identity = ContextIdentity(source_ip=flask_request.remote_addr)

Expand All @@ -342,7 +345,7 @@ def _construct_event(flask_request, port, binary_types):

context = RequestContext(resource_path=endpoint,
http_method=method,
stage="prod",
stage=stage_name,
identity=identity,
path=endpoint)

Expand All @@ -360,7 +363,8 @@ def _construct_event(flask_request, port, binary_types):
multi_value_headers=multi_value_headers_dict,
path_parameters=flask_request.view_args,
path=flask_request.path,
is_base_64_encoded=is_base_64)
is_base_64_encoded=is_base_64,
stage_variables=stage_variables)

event_str = json.dumps(event.to_dict())
LOG.debug("Constructed String representation of Event to invoke Lambda. Event: %s", event_str)
Expand Down
54 changes: 54 additions & 0 deletions tests/integration/local/start_api/test_start_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,3 +496,57 @@ def test_forward_headers_are_added_to_event(self):
self.assertEquals(response_data.get("multiValueHeaders").get("X-Forwarded-Proto"), ["http"])
self.assertEquals(response_data.get("headers").get("X-Forwarded-Port"), self.port)
self.assertEquals(response_data.get("multiValueHeaders").get("X-Forwarded-Port"), [self.port])


class TestStartApiWithStage(StartApiIntegBaseClass):
"""
Test Class centered around the different responses that can happen in Lambda and pass through start-api
"""
template_path = "/testdata/start_api/template.yaml"

def setUp(self):
self.url = "http://127.0.0.1:{}".format(self.port)

def test_default_stage_name(self):
response = requests.get(self.url + "/echoeventbody")

self.assertEquals(response.status_code, 200)

response_data = response.json()

self.assertEquals(response_data.get("requestContext", {}).get("stage"), "Prod")

def test_global_stage_variables(self):
response = requests.get(self.url + "/echoeventbody")

self.assertEquals(response.status_code, 200)

response_data = response.json()

self.assertEquals(response_data.get("stageVariables"), {'VarName': 'varValue'})


class TestStartApiWithStageAndSwagger(StartApiIntegBaseClass):
"""
Test Class centered around the different responses that can happen in Lambda and pass through start-api
"""
template_path = "/testdata/start_api/swagger-template.yaml"

def setUp(self):
self.url = "http://127.0.0.1:{}".format(self.port)

def test_swagger_stage_name(self):
response = requests.get(self.url + "/echoeventbody")

self.assertEquals(response.status_code, 200)

response_data = response.json()
self.assertEquals(response_data.get("requestContext", {}).get("stage"), "dev")

def test_swagger_stage_variable(self):
response = requests.get(self.url + "/echoeventbody")

self.assertEquals(response.status_code, 200)

response_data = response.json()
self.assertEquals(response_data.get("stageVariables"), {'VarName': 'varValue'})
28 changes: 27 additions & 1 deletion tests/integration/testdata/start_api/swagger-template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ Resources:
MyApi:
Type: AWS::Serverless::Api
Properties:
StageName: prod
StageName: dev
Variables:
VarName: varValue
DefinitionBody:
swagger: "2.0"
info:
Expand Down Expand Up @@ -67,6 +69,15 @@ Resources:
uri:
Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${EchoBase64EventBodyFunction.Arn}/invocations

"/echoeventbody":
post:
x-amazon-apigateway-integration:
httpMethod: POST
type: aws_proxy
uri:
Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${EchoEventHandlerFunction.Arn}/invocations


MyLambdaFunction:
Type: AWS::Serverless::Function
Properties:
Expand Down Expand Up @@ -109,3 +120,18 @@ Resources:
Handler: main.echo_base64_event_body
Runtime: python3.6
CodeUri: .

EchoEventHandlerFunction:
Type: AWS::Serverless::Function
Properties:
Handler: main.echo_event_handler
Runtime: python3.6
CodeUri: .
Events:
GetApi:
Type: Api
Properties:
Path: /{proxy+}
Method: GET
RestApiId:
Ref: MyApi
3 changes: 2 additions & 1 deletion tests/integration/testdata/start_api/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ Globals:
# These are equivalent to image/gif and image/png when deployed
- image~1gif
- image~1png

Variables:
VarName: varValue
Resources:
HelloWorldFunction:
Type: AWS::Serverless::Function
Expand Down
Loading

0 comments on commit 513fe03

Please sign in to comment.