diff --git a/samtranslator/model/api/websocket_api_generator.py b/samtranslator/model/api/websocket_api_generator.py index ff3d249c2..9fe71e0fd 100644 --- a/samtranslator/model/api/websocket_api_generator.py +++ b/samtranslator/model/api/websocket_api_generator.py @@ -32,7 +32,7 @@ class WebSocketApiGenerator(ApiV2Generator): def __init__( # noqa: PLR0913 self, logical_id: str, - stage_name: str | None, + stage_name: Intrinsicable[str] | None, stage_variables: ( dict[str, Intrinsicable[str]] | None ), # I tried to keep presence of = None consistent with http @@ -295,14 +295,23 @@ def _construct_permission(self, route_key: str, perms_id: str, route_spec: dict[ perms.Action = "lambda:InvokeFunction" perms.FunctionName = route_spec["FunctionArn"] perms.Principal = "apigateway.amazonaws.com" - perms.SourceArn = fnSub( - "arn:${AWS::Partition}:execute-api:${AWS::Region}:${AWS::AccountId}:${" - + self.logical_id - + ".ApiId}/" - + self.stage_name - + "/" - + route_key - ) + if isinstance(self.stage_name, str): + perms.SourceArn = fnSub( + "arn:${AWS::Partition}:execute-api:${AWS::Region}:${AWS::AccountId}:${" + + self.logical_id + + ".ApiId}/" + + self.stage_name + + "/" + + route_key + ) + else: + perms.SourceArn = fnSub( + "arn:${AWS::Partition}:execute-api:${AWS::Region}:${AWS::AccountId}:${" + + self.logical_id + + ".ApiId}/${__StageName__}/" + + route_key, + {"__StageName__": self.stage_name}, + ) return perms def _construct_route_infr(self, route_key: str, route_spec: dict[str, Any]) -> tuple[ diff --git a/samtranslator/model/sam_resources.py b/samtranslator/model/sam_resources.py index 7fccf78b5..0a3a9ba29 100644 --- a/samtranslator/model/sam_resources.py +++ b/samtranslator/model/sam_resources.py @@ -1910,7 +1910,7 @@ class SamWebSocketApi(SamResourceMacro): "Routes": PropertyType(True, IS_DICT), "RouteSettings": PropertyType(False, IS_DICT), "RouteSelectionExpression": PropertyType(True, IS_STR), - "StageName": PropertyType(False, IS_STR), + "StageName": PropertyType(False, one_of(IS_STR, IS_DICT)), "StageVariables": PropertyType(False, IS_DICT), "Tags": PropertyType(False, IS_DICT), } @@ -1930,7 +1930,7 @@ class SamWebSocketApi(SamResourceMacro): Routes: dict[str, dict[str, Any]] RouteSettings: dict[str, Any] | None RouteSelectionExpression: str - StageName: str | None + StageName: Intrinsicable[str] | None StageVariables: dict[str, Intrinsicable[str]] | None Tags: dict[str, Any] | None diff --git a/tests/model/api/test_websocket_api_generator.py b/tests/model/api/test_websocket_api_generator.py index ec2a5434c..7a37a4786 100644 --- a/tests/model/api/test_websocket_api_generator.py +++ b/tests/model/api/test_websocket_api_generator.py @@ -47,6 +47,16 @@ def test_perms(self): "arn:${AWS::Partition}:execute-api:${AWS::Region}:${AWS::AccountId}:${WebSocketApiId.ApiId}/default/$connect", ) + def test_perms_with_intrinsic_stage_name(self): + """Test that _construct_permission handles intrinsic StageName without TypeError.""" + kwargs = self.kwargs.copy() + kwargs["stage_name"] = {"Ref": "StageName"} + _, _, perm, _ = WebSocketApiGenerator(**kwargs)._construct_route_infr("$connect", kwargs["routes"]["$connect"]) + fn_sub = perm.SourceArn["Fn::Sub"] + self.assertIsInstance(fn_sub, list) + self.assertIn("${__StageName__}", fn_sub[0]) + self.assertEqual(fn_sub[1]["__StageName__"], {"Ref": "StageName"}) + def test_none_auth_no_id(self): kwargs = self.kwargs.copy() kwargs["auth_config"] = {"AuthType": "NONE"}