Skip to content

Commit

Permalink
Workspaces Web: tagging support (#8217)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamarcelin authored Oct 10, 2024
1 parent 5cbfacf commit 5b7010f
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 22 deletions.
53 changes: 31 additions & 22 deletions moto/workspacesweb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from moto.core.common_models import BaseModel
from moto.utilities.utils import get_partition

from ..utilities.tagging_service import TaggingService


class FakeUserSettings(BaseModel):
def __init__(
Expand All @@ -23,7 +25,6 @@ def __init__(
idle_disconnect_timeout_in_minutes: int,
paste_allowed: bool,
print_allowed: bool,
tags: Dict[str, str],
upload_allowed: bool,
region_name: str,
account_id: str,
Expand All @@ -41,7 +42,6 @@ def __init__(
self.idle_disconnect_timeout_in_minutes = idle_disconnect_timeout_in_minutes
self.paste_allowed = paste_allowed if paste_allowed else "Disabled"
self.print_allowed = print_allowed if print_allowed else "Disabled"
self.tags = tags
self.upload_allowed = upload_allowed if upload_allowed else "Disabled"
self.associated_portal_arns: List[str] = []

Expand All @@ -62,7 +62,6 @@ def to_dict(self) -> Dict[str, Any]:
"idleDisconnectTimeoutInMinutes": self.idle_disconnect_timeout_in_minutes,
"pasteAllowed": self.paste_allowed,
"printAllowed": self.print_allowed,
"tags": self.tags,
"uploadAllowed": self.upload_allowed,
"userSettingsArn": self.arn,
}
Expand All @@ -73,7 +72,6 @@ def __init__(
self,
client_token: str,
kinesis_stream_arn: str,
tags: Dict[str, str],
region_name: str,
account_id: str,
):
Expand All @@ -83,7 +81,6 @@ def __init__(
)
self.client_token = client_token
self.kinesis_stream_arn = kinesis_stream_arn
self.tags = tags
self.associated_portal_arns: List[str] = []

def arn_formatter(self, _id: str, account_id: str, region_name: str) -> str:
Expand All @@ -93,7 +90,6 @@ def to_dict(self) -> Dict[str, Any]:
return {
"associatedPortalArns": self.associated_portal_arns,
"kinesisStreamArn": self.kinesis_stream_arn,
"tags": self.tags,
"userAccessLoggingSettingsArn": self.arn,
}

Expand All @@ -103,7 +99,6 @@ def __init__(
self,
security_group_ids: List[str],
subnet_ids: List[str],
tags: Dict[str, str],
vpc_id: str,
region_name: str,
account_id: str,
Expand All @@ -112,7 +107,6 @@ def __init__(
self.arn = self.arn_formatter(self.network_settings_id, account_id, region_name)
self.security_group_ids = security_group_ids
self.subnet_ids = subnet_ids
self.tags = tags
self.vpc_id = vpc_id
self.associated_portal_arns: List[str] = []

Expand All @@ -125,7 +119,6 @@ def to_dict(self) -> Dict[str, Any]:
"networkSettingsArn": self.arn,
"securityGroupIds": self.security_group_ids,
"subnetIds": self.subnet_ids,
"Tags": self.tags,
"vpcId": self.vpc_id,
}

Expand All @@ -137,7 +130,6 @@ def __init__(
browser_policy: str,
client_token: str,
customer_managed_key: str,
tags: Dict[str, str],
region_name: str,
account_id: str,
):
Expand All @@ -147,7 +139,6 @@ def __init__(
self.browser_policy = browser_policy
self.client_token = client_token
self.customer_managed_key = customer_managed_key
self.tags = tags
self.associated_portal_arns: List[str] = []

def arn_formatter(self, _id: str, account_id: str, region_name: str) -> str:
Expand All @@ -160,7 +151,6 @@ def to_dict(self) -> Dict[str, Any]:
"additionalEncryptionContext": self.additional_encryption_context,
"browserPolicy": self.browser_policy,
"customerManagedKey": self.customer_managed_key,
"tags": self.tags,
}


Expand All @@ -174,7 +164,6 @@ def __init__(
display_name: str,
instance_type: str,
max_concurrent_sessions: str,
tags: Dict[str, str],
region_name: str,
account_id: str,
):
Expand All @@ -187,7 +176,6 @@ def __init__(
self.display_name = display_name
self.instance_type = instance_type
self.max_concurrent_sessions = max_concurrent_sessions
self.tags = tags
self.portal_endpoint = f"{self.portal_id}.portal.aws"
self.browser_type = "Chrome"
self.creation_time = datetime.datetime.now().isoformat()
Expand Down Expand Up @@ -225,7 +213,6 @@ def to_dict(self) -> Dict[str, Any]:
"trustStoreArn": self.trust_store_arn,
"userAccessLoggingSettingsArn": self.user_access_logging_settings_arn,
"userSettingsArn": self.user_settings_arn,
"tags": self.tags,
}


Expand All @@ -239,6 +226,7 @@ def __init__(self, region_name: str, account_id: str):
self.user_settings: Dict[str, FakeUserSettings] = {}
self.user_access_logging_settings: Dict[str, FakeUserAccessLoggingSettings] = {}
self.portals: Dict[str, FakePortal] = {}
self.tagger = TaggingService()

def create_network_settings(
self,
Expand All @@ -250,12 +238,13 @@ def create_network_settings(
network_settings_object = FakeNetworkSettings(
security_group_ids,
subnet_ids,
tags,
vpc_id,
self.region_name,
self.account_id,
)
self.network_settings[network_settings_object.arn] = network_settings_object
if tags:
self.tag_resource("TEMP_CLIENT_TOKEN", network_settings_object.arn, tags)
return network_settings_object.arn

def list_network_settings(self) -> List[Dict[str, str]]:
Expand All @@ -276,18 +265,19 @@ def create_browser_settings(
browser_policy: str,
client_token: str,
customer_managed_key: str,
tags: Dict[str, str],
tags: Optional[List[Dict[str, str]]] = None,
) -> str:
browser_settings_object = FakeBrowserSettings(
additional_encryption_context,
browser_policy,
client_token,
customer_managed_key,
tags,
self.region_name,
self.account_id,
)
self.browser_settings[browser_settings_object.arn] = browser_settings_object
if tags:
self.tag_resource(client_token, browser_settings_object.arn, tags)
return browser_settings_object.arn

def list_browser_settings(self) -> List[Dict[str, str]]:
Expand All @@ -311,7 +301,7 @@ def create_portal(
display_name: str,
instance_type: str,
max_concurrent_sessions: str,
tags: Dict[str, str],
tags: Optional[List[Dict[str, str]]] = None,
) -> Tuple[str, str]:
portal_object = FakePortal(
additional_encryption_context,
Expand All @@ -321,11 +311,12 @@ def create_portal(
display_name,
instance_type,
max_concurrent_sessions,
tags,
self.region_name,
self.account_id,
)
self.portals[portal_object.arn] = portal_object
if tags:
self.tag_resource(client_token, portal_object.arn, tags)
return portal_object.arn, portal_object.portal_endpoint

def list_portals(self) -> List[Dict[str, Any]]:
Expand Down Expand Up @@ -405,12 +396,13 @@ def create_user_settings(
idle_disconnect_timeout_in_minutes,
paste_allowed,
print_allowed,
tags,
upload_allowed,
self.region_name,
self.account_id,
)
self.user_settings[user_settings_object.arn] = user_settings_object
if tags:
self.tag_resource(client_token, user_settings_object.arn, tags)
return user_settings_object.arn

def get_user_settings(self, user_settings_arn: str) -> Dict[str, Any]:
Expand All @@ -423,11 +415,15 @@ def create_user_access_logging_settings(
self, client_token: Any, kinesis_stream_arn: Any, tags: Any
) -> str:
user_access_logging_settings_object = FakeUserAccessLoggingSettings(
client_token, kinesis_stream_arn, tags, self.region_name, self.account_id
client_token, kinesis_stream_arn, self.region_name, self.account_id
)
self.user_access_logging_settings[user_access_logging_settings_object.arn] = (
user_access_logging_settings_object
)
if tags:
self.tag_resource(
client_token, user_access_logging_settings_object.arn, tags
)
return user_access_logging_settings_object.arn

def get_user_access_logging_settings(
Expand Down Expand Up @@ -476,5 +472,18 @@ def list_user_access_logging_settings(self) -> List[Dict[str, str]]:
for user_access_logging_settings in self.user_access_logging_settings.values()
]

def tag_resource(self, client_token: str, resource_arn: str, tags: Any) -> None:
self.tagger.tag_resource(resource_arn, tags)

def untag_resource(self, resource_arn: str, tag_keys: Any) -> None:
self.tagger.untag_resource_using_names(resource_arn, tag_keys)

def list_tags_for_resource(self, resource_arn: str) -> List[Dict[str, str]]:
tags = self.tagger.get_tag_dict_for_resource(resource_arn)
Tags = []
for key, value in tags.items():
Tags.append({"Key": key, "Value": value})
return Tags


workspacesweb_backends = BackendDict(WorkSpacesWebBackend, "workspaces-web")
27 changes: 27 additions & 0 deletions moto/workspacesweb/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,3 +338,30 @@ def list_user_access_logging_settings(self) -> str:
self.workspacesweb_backend.list_user_access_logging_settings()
)
return json.dumps(dict(userAccessLoggingSettings=user_access_logging_settings))

def tag_resource(self) -> str:
client_token = self._get_param("clientToken")
resource_arn = unquote(self._get_param("resourceArn"))
tags = self._get_param("tags")
self.workspacesweb_backend.tag_resource(
client_token=client_token,
resource_arn=resource_arn,
tags=tags,
)
return json.dumps(dict())

def untag_resource(self) -> str:
tagKeys = self.__dict__["data"]["tagKeys"]
resource_arn = unquote(self._get_param("resourceArn"))
self.workspacesweb_backend.untag_resource(
resource_arn=resource_arn,
tag_keys=tagKeys,
)
return json.dumps(dict())

def list_tags_for_resource(self) -> str:
resource_arn = unquote(self.parsed_url.path.split("/tags/")[-1])
tags = self.workspacesweb_backend.list_tags_for_resource(
resource_arn=resource_arn,
)
return json.dumps(dict(tags=tags))
1 change: 1 addition & 0 deletions moto/workspacesweb/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@
"{0}/portals/(?P<portalArn>.*)/networkSettings$": WorkSpacesWebResponse.dispatch,
"{0}/portals/(?P<portalArn>.*)/userSettings$": WorkSpacesWebResponse.dispatch,
"{0}/portals/(?P<portalArn>.*)/userAccessLoggingSettings$": WorkSpacesWebResponse.dispatch,
"{0}/tags/(?P<resourceArn>.+)$": WorkSpacesWebResponse.dispatch,
}
78 changes: 78 additions & 0 deletions tests/test_workspacesweb/test_workspacesweb.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
FAKE_SUBNET_IDS = ["subnet-0123456789abcdef0", "subnet-abcdef0123456789"]
FAKE_TAGS = [
{"Key": "TestKey", "Value": "TestValue"},
{"Key": "TestKey2", "Value": "TestValue2"},
]
FAKE_VPC_ID = "vpc-0123456789abcdef0"
FAKE_KMS_KEY_ID = "abcd1234-5678-90ab-cdef-FAKEKEY"
Expand Down Expand Up @@ -436,3 +437,80 @@ def test_associate_user_access_logging_settings():
userAccessLoggingSettingsArn=user_access_logging_settings_arn
)["userAccessLoggingSettings"]
assert resp["associatedPortalArns"] == [portal_arn]


@mock_aws
def test_list_tags_for_resource():
client = boto3.client("workspaces-web", region_name="ap-southeast-1")

OTHER_FAKE_TAGS = [
{"Key": "FAKEKEY1", "Value": "FAKEVALUE1"},
{"Key": "FAKEKEY2", "Value": "FAKEVALUE2"},
]

arns = []

arns.append(
client.create_portal(
additionalEncryptionContext={"Key1": "Encryption", "Key2": "Context"},
authenticationType="Standard",
clientToken="TestClient",
customerManagedKey=FAKE_KMS_KEY_ID,
displayName="TestDisplayName",
instanceType="TestInstanceType",
maxConcurrentSessions=5,
tags=FAKE_TAGS,
)["portalArn"]
)

arns.append(
client.create_network_settings(
securityGroupIds=FAKE_SECURITY_GROUP_IDS,
subnetIds=FAKE_SUBNET_IDS,
tags=FAKE_TAGS,
vpcId=FAKE_VPC_ID,
)["networkSettingsArn"]
)

arns.append(
client.create_user_settings(
copyAllowed="Disabled",
pasteAllowed="Disabled",
printAllowed="Disabled",
uploadAllowed="Disabled",
downloadAllowed="Disabled",
tags=FAKE_TAGS,
)["userSettingsArn"]
)

arns.append(
client.create_user_access_logging_settings(
kinesisStreamArn="arn:aws:kinesis:ap-southeast-1:123456789012:stream/TestStream",
tags=FAKE_TAGS,
)["userAccessLoggingSettingsArn"]
)

arns.append(
client.create_browser_settings(
additionalEncryptionContext={"Key1": "Value1", "Key2": "Value2"},
browserPolicy="TestBrowserPolicy",
clientToken="TestClient",
customerManagedKey=FAKE_KMS_KEY_ID,
tags=FAKE_TAGS,
)["browserSettingsArn"]
)

for arn in arns:
resp = client.list_tags_for_resource(resourceArn=arn)
assert resp["tags"] == FAKE_TAGS

client.tag_resource(resourceArn=arn, tags=OTHER_FAKE_TAGS)
resp = client.list_tags_for_resource(resourceArn=arn)
assert resp["tags"] == FAKE_TAGS + OTHER_FAKE_TAGS

client.untag_resource(resourceArn=arn, tagKeys=["FAKEKEY1", "TestKey"])
resp = client.list_tags_for_resource(resourceArn=arn)
assert resp["tags"] == [
{"Key": "TestKey2", "Value": "TestValue2"},
{"Key": "FAKEKEY2", "Value": "FAKEVALUE2"},
]

0 comments on commit 5b7010f

Please sign in to comment.