Skip to content

Commit

Permalink
Add AWS operators to create and delete RDS Database (#24099)
Browse files Browse the repository at this point in the history
* Add RdsCreateDbInstanceOperator

* Add RdsDeleteDbInstanceOperator
  • Loading branch information
eskarimov authored Jun 28, 2022
1 parent c7feb31 commit bf72752
Show file tree
Hide file tree
Showing 6 changed files with 350 additions and 72 deletions.
129 changes: 58 additions & 71 deletions airflow/providers/amazon/aws/example_dags/example_dms.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
DmsStartTaskOperator,
DmsStopTaskOperator,
)
from airflow.providers.amazon.aws.operators.rds import (
RdsCreateDbInstanceOperator,
RdsDeleteDbInstanceOperator,
)
from airflow.providers.amazon.aws.sensors.dms import DmsTaskBaseSensor, DmsTaskCompletedSensor

S3_BUCKET = os.getenv('S3_BUCKET', 's3_bucket_name')
Expand Down Expand Up @@ -109,29 +113,20 @@
}


def _create_rds_instance():
print('Creating RDS Instance.')

def _get_rds_instance_endpoint():
print('Retrieving RDS instance endpoint.')
rds_client = boto3.client('rds')
rds_client.create_db_instance(
DBName=RDS_DB_NAME,
DBInstanceIdentifier=RDS_INSTANCE_NAME,
AllocatedStorage=20,
DBInstanceClass='db.t3.micro',
Engine=RDS_ENGINE,
MasterUsername=RDS_USERNAME,
MasterUserPassword=RDS_PASSWORD,
)

rds_client.get_waiter('db_instance_available').wait(DBInstanceIdentifier=RDS_INSTANCE_NAME)

response = rds_client.describe_db_instances(DBInstanceIdentifier=RDS_INSTANCE_NAME)
return response['DBInstances'][0]['Endpoint']
rds_instance_endpoint = response['DBInstances'][0]['Endpoint']
return rds_instance_endpoint


def _create_rds_table(rds_endpoint):
print('Creating table.')
@task
def create_sample_table():
print('Creating sample table.')

rds_endpoint = _get_rds_instance_endpoint()
hostname = rds_endpoint['Address']
port = rds_endpoint['Port']
rds_url = f'{RDS_PROTOCOL}://{RDS_USERNAME}:{RDS_PASSWORD}@{hostname}:{port}/{RDS_DB_NAME}'
Expand All @@ -154,18 +149,21 @@ def _create_rds_table(rds_endpoint):
connection.execute(table.select())


def _create_dms_replication_instance(ti, dms_client):
@task
def create_dms_assets():
print('Creating DMS assets.')
ti = get_current_context()['ti']
dms_client = boto3.client('dms')
rds_instance_endpoint = _get_rds_instance_endpoint()

print('Creating replication instance.')
instance_arn = dms_client.create_replication_instance(
ReplicationInstanceIdentifier=DMS_REPLICATION_INSTANCE_NAME,
ReplicationInstanceClass='dms.t3.micro',
)['ReplicationInstance']['ReplicationInstanceArn']

ti.xcom_push(key='replication_instance_arn', value=instance_arn)
return instance_arn


def _create_dms_endpoints(ti, dms_client, rds_instance_endpoint):
print('Creating DMS source endpoint.')
source_endpoint_arn = dms_client.create_endpoint(
EndpointIdentifier=SOURCE_ENDPOINT_IDENTIFIER,
Expand Down Expand Up @@ -194,28 +192,16 @@ def _create_dms_endpoints(ti, dms_client, rds_instance_endpoint):
ti.xcom_push(key='source_endpoint_arn', value=source_endpoint_arn)
ti.xcom_push(key='target_endpoint_arn', value=target_endpoint_arn)


def _await_setup_assets(dms_client, instance_arn):
print("Awaiting asset provisioning.")
print("Awaiting replication instance provisioning.")
dms_client.get_waiter('replication_instance_available').wait(
Filters=[{'Name': 'replication-instance-arn', 'Values': [instance_arn]}]
)


def _delete_rds_instance():
    """
    Delete the RDS instance identified by ``RDS_INSTANCE_NAME`` and wait for the deletion.

    The instance is deleted with ``SkipFinalSnapshot=True``; the function then blocks on the
    boto3 ``db_instance_deleted`` waiter for the same instance identifier.
    """
    print('Deleting RDS Instance.')

    client = boto3.client('rds')
    client.delete_db_instance(DBInstanceIdentifier=RDS_INSTANCE_NAME, SkipFinalSnapshot=True)

    # Block until the waiter reports the instance as deleted.
    deleted_waiter = client.get_waiter('db_instance_deleted')
    deleted_waiter.wait(DBInstanceIdentifier=RDS_INSTANCE_NAME)


def _delete_dms_assets(dms_client):
@task(trigger_rule='all_done')
def delete_dms_assets():
ti = get_current_context()['ti']
dms_client = boto3.client('dms')
replication_instance_arn = ti.xcom_pull(key='replication_instance_arn')
source_arn = ti.xcom_pull(key='source_endpoint_arn')
target_arn = ti.xcom_pull(key='target_endpoint_arn')
Expand All @@ -225,13 +211,10 @@ def _delete_dms_assets(dms_client):
dms_client.delete_endpoint(EndpointArn=source_arn)
dms_client.delete_endpoint(EndpointArn=target_arn)


def _await_all_teardowns(dms_client):
print('Awaiting tear-down.')
print('Awaiting DMS assets tear-down.')
dms_client.get_waiter('replication_instance_deleted').wait(
Filters=[{'Name': 'replication-instance-id', 'Values': [DMS_REPLICATION_INSTANCE_NAME]}]
)

dms_client.get_waiter('endpoint_deleted').wait(
Filters=[
{
Expand All @@ -242,27 +225,6 @@ def _await_all_teardowns(dms_client):
)


@task
def set_up():
    """
    Provision everything the DMS example needs, in order.

    Creates the RDS instance, creates the table on it, creates the DMS replication
    instance and the DMS endpoints, and finally waits for the set-up assets.
    The task instance and a boto3 DMS client are created once and handed to the helpers.
    """
    task_instance = get_current_context()['ti']
    dms = boto3.client('dms')

    endpoint = _create_rds_instance()
    _create_rds_table(endpoint)
    replication_instance_arn = _create_dms_replication_instance(task_instance, dms)
    _create_dms_endpoints(task_instance, dms, endpoint)
    _await_setup_assets(dms, replication_instance_arn)


@task(trigger_rule='all_done')
def clean_up():
    """
    Tear down the example's resources.

    Deletes the RDS instance, deletes the DMS assets, then waits for the tear-downs,
    sharing a single boto3 DMS client between the DMS helpers.
    """
    dms = boto3.client('dms')

    _delete_rds_instance()
    _delete_dms_assets(dms)
    _await_all_teardowns(dms)


with DAG(
dag_id='example_dms',
schedule_interval=None,
Expand All @@ -271,6 +233,19 @@ def clean_up():
catchup=False,
) as dag:

create_db_instance = RdsCreateDbInstanceOperator(
task_id="create_db_instance",
db_instance_identifier=RDS_INSTANCE_NAME,
db_instance_class='db.t3.micro',
engine=RDS_ENGINE,
rds_kwargs={
"DBName": RDS_DB_NAME,
"AllocatedStorage": 20,
"MasterUsername": RDS_USERNAME,
"MasterUserPassword": RDS_PASSWORD,
},
)

# [START howto_operator_dms_create_task]
create_task = DmsCreateTaskOperator(
task_id='create_task',
Expand Down Expand Up @@ -334,14 +309,26 @@ def clean_up():
)
# [END howto_operator_dms_delete_task]

delete_db_instance = RdsDeleteDbInstanceOperator(
task_id='delete_db_instance',
db_instance_identifier=RDS_INSTANCE_NAME,
rds_kwargs={
"SkipFinalSnapshot": True,
},
trigger_rule='all_done',
)

chain(
set_up()
>> create_task
>> start_task
>> describe_tasks
>> await_task_start
>> stop_task
>> await_task_stop
>> delete_task
>> clean_up()
create_db_instance,
create_sample_table(),
create_dms_assets(),
create_task,
start_task,
describe_tasks,
await_task_start,
stop_task,
await_task_stop,
delete_task,
delete_dms_assets(),
delete_db_instance,
)
106 changes: 105 additions & 1 deletion airflow/providers/amazon/aws/operators/rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import json
import time
from typing import TYPE_CHECKING, List, Optional, Sequence
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence

from mypy_boto3_rds.type_defs import TagTypeDef

Expand Down Expand Up @@ -551,6 +551,108 @@ def execute(self, context: 'Context') -> str:
return json.dumps(delete_subscription, default=str)


class RdsCreateDbInstanceOperator(RdsBaseOperator):
    """
    Create an RDS DB instance via the boto3 RDS client's ``create_db_instance`` call.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:RdsCreateDbInstanceOperator`

    :param db_instance_identifier: The DB instance identifier, must start with a letter and
        contain from 1 to 63 letters, numbers, or hyphens
    :param db_instance_class: The compute and memory capacity of the DB instance, for example db.m5.large
    :param engine: The name of the database engine to be used for this instance
    :param rds_kwargs: Extra named arguments forwarded to boto3 RDS client function ``create_db_instance``
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_instance
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param wait_for_completion: Whether or not to wait for creation of the DB instance to
        complete. (default: True)
    """

    def __init__(
        self,
        *,
        db_instance_identifier: str,
        db_instance_class: str,
        engine: str,
        rds_kwargs: Optional[Dict] = None,
        aws_conn_id: str = "aws_default",
        wait_for_completion: bool = True,
        **kwargs,
    ):
        super().__init__(aws_conn_id=aws_conn_id, **kwargs)

        self.db_instance_identifier = db_instance_identifier
        self.db_instance_class = db_instance_class
        self.engine = engine
        # A missing (or empty) mapping is stored as a fresh empty dict.
        self.rds_kwargs = rds_kwargs if rds_kwargs else {}
        self.wait_for_completion = wait_for_completion

    def execute(self, context: 'Context') -> str:
        """
        Issue the create request and optionally wait for the instance to become available.

        :return: The ``create_db_instance`` response serialized as a JSON string; values that
            are not JSON-serializable are converted with ``str``.
        """
        identifier = self.db_instance_identifier
        self.log.info("Creating new DB instance %s", identifier)

        response = self.hook.conn.create_db_instance(
            DBInstanceIdentifier=identifier,
            DBInstanceClass=self.db_instance_class,
            Engine=self.engine,
            **self.rds_kwargs,
        )

        if self.wait_for_completion:
            available_waiter = self.hook.conn.get_waiter("db_instance_available")
            available_waiter.wait(DBInstanceIdentifier=identifier)

        return json.dumps(response, default=str)


class RdsDeleteDbInstanceOperator(RdsBaseOperator):
    """
    Delete an RDS DB instance via the boto3 RDS client's ``delete_db_instance`` call.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:RdsDeleteDbInstanceOperator`

    :param db_instance_identifier: The DB instance identifier for the DB instance to be deleted
    :param rds_kwargs: Extra named arguments forwarded to boto3 RDS client function ``delete_db_instance``
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.delete_db_instance
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param wait_for_completion: Whether or not to wait for deletion of the DB instance to
        complete. (default: True)
    """

    def __init__(
        self,
        *,
        db_instance_identifier: str,
        rds_kwargs: Optional[Dict] = None,
        aws_conn_id: str = "aws_default",
        wait_for_completion: bool = True,
        **kwargs,
    ):
        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
        self.db_instance_identifier = db_instance_identifier
        # A missing (or empty) mapping is stored as a fresh empty dict.
        self.rds_kwargs = rds_kwargs if rds_kwargs else {}
        self.wait_for_completion = wait_for_completion

    def execute(self, context: 'Context') -> str:
        """
        Issue the delete request and optionally wait until the instance is deleted.

        :return: The ``delete_db_instance`` response serialized as a JSON string; values that
            are not JSON-serializable are converted with ``str``.
        """
        identifier = self.db_instance_identifier
        self.log.info("Deleting DB instance %s", identifier)

        response = self.hook.conn.delete_db_instance(
            DBInstanceIdentifier=identifier,
            **self.rds_kwargs,
        )

        if self.wait_for_completion:
            deleted_waiter = self.hook.conn.get_waiter("db_instance_deleted")
            deleted_waiter.wait(DBInstanceIdentifier=identifier)

        return json.dumps(response, default=str)


__all__ = [
"RdsCreateDbSnapshotOperator",
"RdsCopyDbSnapshotOperator",
Expand All @@ -559,4 +661,6 @@ def execute(self, context: 'Context') -> str:
"RdsDeleteEventSubscriptionOperator",
"RdsStartExportTaskOperator",
"RdsCancelExportTaskOperator",
"RdsCreateDbInstanceOperator",
"RdsDeleteDbInstanceOperator",
]
28 changes: 28 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/rds.rst
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,34 @@ To delete an Amazon RDS event subscription you can use
:start-after: [START howto_operator_rds_delete_event_subscription]
:end-before: [END howto_operator_rds_delete_event_subscription]

.. _howto/operator:RdsCreateDbInstanceOperator:

Create a database instance
==========================

To create an AWS DB instance you can use
:class:`~airflow.providers.amazon.aws.operators.rds.RdsCreateDbInstanceOperator`.

.. exampleinclude:: /../../tests/system/providers/amazon/aws/rds/example_rds_instance.py
:language: python
:dedent: 4
:start-after: [START howto_operator_rds_create_db_instance]
:end-before: [END howto_operator_rds_create_db_instance]

.. _howto/operator:RdsDeleteDbInstanceOperator:

Delete a database instance
==========================

To delete an AWS DB instance you can use
:class:`~airflow.providers.amazon.aws.operators.rds.RdsDeleteDbInstanceOperator`.

.. exampleinclude:: /../../tests/system/providers/amazon/aws/rds/example_rds_instance.py
:language: python
:dedent: 4
:start-after: [START howto_operator_rds_delete_db_instance]
:end-before: [END howto_operator_rds_delete_db_instance]

Sensors
-------

Expand Down
Loading

0 comments on commit bf72752

Please sign in to comment.