Skip to content

Commit

Permalink
Merge pull request #202 from adanaja/feat
Browse files Browse the repository at this point in the history
Add back possibility to have multiple domains
  • Loading branch information
adanaja authored Mar 28, 2023
2 parents 81f02ac + 805098a commit 3d1ee29
Show file tree
Hide file tree
Showing 4 changed files with 405 additions and 108 deletions.
125 changes: 76 additions & 49 deletions src/e3/aws/troposphere/apigateway/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING
from abc import abstractmethod
from e3.aws import name_to_id
from e3.aws.troposphere import Construct
Expand All @@ -27,7 +27,7 @@

if TYPE_CHECKING:
from e3.aws.troposphere import Stack
from typing import Any
from typing import Any, TypedDict


class AuthorizationType(Enum):
Expand Down Expand Up @@ -161,6 +161,12 @@ def __init__(
class Api(Construct):
"""API abstact Class for APIGateways V1 and V2."""

if TYPE_CHECKING:

class _AliasTargetAttributes(TypedDict):
DNSName: str
HostedZoneId: str

def __init__(
self,
name: str,
Expand Down Expand Up @@ -278,37 +284,41 @@ def declare_stage(
"""
pass

def declare_certificate(self) -> Certificate:
def _declare_certificate(
self, domain_name: str, hosted_zone_id: str
) -> Certificate:
"""Declare the API's domain certificate.
:param domain_name: domain name
:param hosted_zone_id: hosted zone in which the domain belongs to
:return: domain name certificate
"""
return Certificate(
name_to_id(self.name + cast(str, self.domain_name) + "Certificate"),
DomainName=self.domain_name,
name_to_id(self.name + domain_name + "Certificate"),
DomainName=domain_name,
DomainValidationOptions=[
DomainValidationOption(
DomainName=self.domain_name, HostedZoneId=self.hosted_zone_id
DomainName=domain_name, HostedZoneId=hosted_zone_id
)
],
ValidationMethod="DNS",
)

@abstractmethod
def declare_domain_name(
self, domain_name_id: str, certificate_arn: Ref | str
def _declare_domain_name(
self, domain_name: str, certificate_arn: Ref | str
) -> apigatewayv2.DomainName | apigateway.DomainName:
"""Declare the domain name aws resource of the API.
:param domain_name_id: the domain name's logical id
:param domain_name: domain name
:param certificate_arn: the ARN of the certificate
:return: the domain name aws resource
"""
pass

@abstractmethod
def declare_api_mapping(
self, domain_name: Ref | str
def _declare_api_mapping(
self, domain_name: apigatewayv2.DomainName | apigateway.DomainName
) -> list[BasePathMapping | ApiMapping]:
"""Declare the API's mapping.
Expand All @@ -317,41 +327,50 @@ def declare_api_mapping(
"""
pass

def declare_domain(self, alias_target: dict[str, str]) -> list[AWSObject]:
@abstractmethod
def _get_alias_target_attributes(self) -> Api._AliasTargetAttributes:
"""Get atributes to pass to GetAtt for alias target."""
pass

def declare_domain(self, domain_name: str, hosted_zone_id: str) -> list[AWSObject]:
"""Declare a custom domain for the API stages.
Note that when a custom domain is created then a certificate is automatically
created for that domain.
:param alias_target: atributes to GetAtt for alias target
:param domain_name: domain name
:param hosted_zone_id: hosted zone in which the domain belongs to
:return: a list of AWSObject
"""
result = []
domain_name_id = name_to_id(self.name + cast(str, self.domain_name) + "Domain")

certificate = self.declare_certificate()
certificate = self._declare_certificate(
domain_name=domain_name, hosted_zone_id=hosted_zone_id
)
result.append(certificate)

domain = self.declare_domain_name(
domain_name_id=domain_name_id, certificate_arn=certificate.ref()
domain = self._declare_domain_name(
domain_name=domain_name,
certificate_arn=certificate.ref(),
)
result.append(domain)

result += self.declare_api_mapping(domain.ref())
result += self._declare_api_mapping(domain)
alias_target = self._get_alias_target_attributes()

result.append(
route53.RecordSetType(
name_to_id(self.name + cast(str, self.domain_name) + "DNS"),
Name=self.domain_name,
name_to_id(self.name + domain_name + "DNS"),
Name=domain_name,
Type="A",
HostedZoneId=self.hosted_zone_id,
HostedZoneId=hosted_zone_id,
AliasTarget=route53.AliasTarget(
DNSName=GetAtt(
domain_name_id,
domain.title,
alias_target["DNSName"],
),
HostedZoneId=GetAtt(
domain_name_id,
domain.title,
alias_target["HostedZoneId"],
),
EvaluateTargetHealth=False,
Expand Down Expand Up @@ -543,25 +562,25 @@ def declare_route(self, route: Route, integration: Ref | str) -> list[AWSObject]
)
return result

def declare_domain_name(
self, domain_name_id: str, certificate_arn: Ref | str
def _declare_domain_name(
self, domain_name: str, certificate_arn: Ref | str
) -> apigatewayv2.DomainName | apigateway.DomainName:
"""Declare the domain name aws resource of the API.
:param domain_name_id: the domain name's name id
:param domain_name: domain name
:param certificate_arn: the ARN of the certificate
:return: a domain name aws resource
"""
return apigatewayv2.DomainName(
domain_name_id,
DomainName=self.domain_name,
name_to_id(self.name + domain_name + "Domain"),
DomainName=domain_name,
DomainNameConfigurations=[
apigatewayv2.DomainNameConfiguration(CertificateArn=certificate_arn)
],
)

def declare_api_mapping(
self, domain_name: Ref | str
def _declare_api_mapping(
self, domain_name: apigatewayv2.DomainName | apigateway.DomainName
) -> list[BasePathMapping | ApiMapping]:
"""Declare the API's mapping.
Expand All @@ -571,7 +590,7 @@ def declare_api_mapping(
result = []
for config in self.stages_config:
mapping_params = {
"DomainName": domain_name,
"DomainName": domain_name.ref(),
"Stage": self.stage_ref(config.name),
"ApiId": self.ref,
}
Expand All @@ -584,7 +603,7 @@ def declare_api_mapping(
name_to_id(
"{}{}-{}ApiMapping".format(
self.name,
self.domain_name,
domain_name.DomainName,
"" if config.name == "$default" else config.name,
)
),
Expand All @@ -593,6 +612,13 @@ def declare_api_mapping(
)
return result

def _get_alias_target_attributes(self) -> Api._AliasTargetAttributes:
"""Get atributes to pass to GetAtt for alias target."""
return {
"DNSName": "RegionalDomainName",
"HostedZoneId": "RegionalHostedZoneId",
}

def resources(self, stack: Stack) -> list[AWSObject]:
"""Return list of AWSObject associated with the construct."""
result = []
Expand Down Expand Up @@ -640,10 +666,7 @@ def resources(self, stack: Stack) -> list[AWSObject]:
if self.domain_name is not None:
assert self.hosted_zone_id is not None
result += self.declare_domain(
alias_target={
"DNSName": "RegionalDomainName",
"HostedZoneId": "RegionalHostedZoneId",
}
domain_name=self.domain_name, hosted_zone_id=self.hosted_zone_id
)

# Declare the authorizers
Expand Down Expand Up @@ -888,23 +911,23 @@ def declare_method(self, method: Method, resource_id: Ref) -> list[AWSObject]:
)
return result

def declare_domain_name(
self, domain_name_id: str, certificate_arn: Ref | str
def _declare_domain_name(
self, domain_name: str, certificate_arn: Ref | str
) -> apigatewayv2.DomainName | apigateway.DomainName:
"""Declare the domain name aws resource of the API.
:param domain_name_id: the domain name's name id
:param domain_name: domain name
:param certificate_arn: the ARN of the certificate
:return: a domain name aws resource
"""
return apigateway.DomainName(
domain_name_id,
DomainName=self.domain_name,
name_to_id(self.name + domain_name + "Domain"),
DomainName=domain_name,
CertificateArn=certificate_arn,
)

def declare_api_mapping(
self, domain_name: Ref | str
def _declare_api_mapping(
self, domain_name: apigatewayv2.DomainName | apigateway.DomainName
) -> list[BasePathMapping | ApiMapping]:
"""Declare the API's mapping.
Expand All @@ -914,7 +937,7 @@ def declare_api_mapping(
result = []
for config in self.stages_config:
mapping_params = {
"DomainName": domain_name,
"DomainName": domain_name.ref(),
"Stage": self.stage_ref(config.name),
"RestApiId": self.ref,
}
Expand All @@ -927,7 +950,7 @@ def declare_api_mapping(
name_to_id(
"{}{}-{}BasePathMapping".format(
self.name,
self.domain_name,
domain_name.DomainName,
"" if config.name == "$default" else config.name,
)
),
Expand All @@ -936,6 +959,13 @@ def declare_api_mapping(
)
return result

def _get_alias_target_attributes(self) -> Api._AliasTargetAttributes:
"""Get atributes to pass to GetAtt for alias target."""
return {
"DNSName": "DistributionDomainName",
"HostedZoneId": "DistributionHostedZoneId",
}

def declare_access_cloudwatch_resources(self) -> list[AWSObject]:
"""Create role and account resources to enable CloudWatch Logs.
Expand Down Expand Up @@ -1020,10 +1050,7 @@ def resources(self, stack: Stack) -> list[AWSObject]:
if self.domain_name is not None:
assert self.hosted_zone_id is not None
result += self.declare_domain(
alias_target={
"DNSName": "DistributionDomainName",
"HostedZoneId": "DistributionHostedZoneId",
}
domain_name=self.domain_name, hosted_zone_id=self.hosted_zone_id
)

# Declare the authorizers
Expand Down
Loading

0 comments on commit 3d1ee29

Please sign in to comment.