Skip to content

Commit

Permalink
Refactor GlueJobHook get_or_create_glue_job method. (#24215)
Browse files Browse the repository at this point in the history
When invoked, create_job takes into account the provided 'Command' argument instead of having it hardcoded.
  • Loading branch information
gmcrocetti authored Jun 6, 2022
1 parent fd4e344 commit 41898d8
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 14 deletions.
10 changes: 8 additions & 2 deletions airflow/providers/amazon/aws/hooks/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,20 @@ def get_or_create_glue_job(self) -> str:
s3_log_path = f's3://{self.s3_bucket}/{self.s3_glue_logs}{self.job_name}'
execution_role = self.get_iam_execution_role()
try:
default_command = {
"Name": "glueetl",
"ScriptLocation": self.script_location,
}
command = self.create_job_kwargs.get("Command", default_command)

if "WorkerType" in self.create_job_kwargs and "NumberOfWorkers" in self.create_job_kwargs:
create_job_response = glue_client.create_job(
Name=self.job_name,
Description=self.desc,
LogUri=s3_log_path,
Role=execution_role['Role']['Arn'],
ExecutionProperty={"MaxConcurrentRuns": self.concurrent_run_limit},
Command={"Name": "glueetl", "ScriptLocation": self.script_location},
Command=command,
MaxRetries=self.retry_limit,
**self.create_job_kwargs,
)
Expand All @@ -199,7 +205,7 @@ def get_or_create_glue_job(self) -> str:
LogUri=s3_log_path,
Role=execution_role['Role']['Arn'],
ExecutionProperty={"MaxConcurrentRuns": self.concurrent_run_limit},
Command={"Name": "glueetl", "ScriptLocation": self.script_location},
Command=command,
MaxRetries=self.retry_limit,
MaxCapacity=self.num_of_dpus,
**self.create_job_kwargs,
Expand Down
57 changes: 45 additions & 12 deletions tests/providers/amazon/aws/hooks/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from airflow.providers.amazon.aws.hooks.glue import GlueJobHook

try:
from moto import mock_iam
from moto import mock_glue, mock_iam
except ImportError:
mock_iam = None
mock_iam = mock_glue = None


class TestGlueJobHook(unittest.TestCase):
Expand Down Expand Up @@ -57,23 +57,56 @@ def test_get_iam_execution_role(self):
assert "Arn" in iam_role['Role']
assert iam_role['Role']['Arn'] == "arn:aws:iam::123456789012:role/my_test_role"

@mock.patch.object(GlueJobHook, "get_iam_execution_role")
@mock.patch.object(GlueJobHook, "get_conn")
def test_get_or_create_glue_job(self, mock_get_conn, mock_get_iam_execution_role):
mock_get_iam_execution_role.return_value = mock.MagicMock(Role={'RoleName': 'my_test_role'})
def test_get_or_create_glue_job_get_existing_job(self, mock_get_conn):
"""
Calls 'get_or_create_glue_job' with an existing job.
Should retrieve existing one.
"""
expected_job_name = "simple-job"
mock_get_conn.return_value.get_job.return_value = {"Job": {"Name": expected_job_name}}

some_script = "s3:/glue-examples/glue-scripts/sample_aws_glue_job.py"
some_s3_bucket = "my-includes"

mock_glue_job = mock_get_conn.return_value.get_job()['Job']['Name']
glue_job = GlueJobHook(
job_name='aws_test_glue_job',
desc='This is test case job from Airflow',
hook = GlueJobHook(
job_name="aws_test_glue_job",
desc="This is test case job from Airflow",
script_location=some_script,
iam_role_name='my_test_role',
iam_role_name="my_test_role",
s3_bucket=some_s3_bucket,
region_name=self.some_aws_region,
).get_or_create_glue_job()
assert glue_job == mock_glue_job
)

result = hook.get_or_create_glue_job()

mock_get_conn.assert_called_once()
mock_get_conn.return_value.get_job.assert_called_once_with(JobName=hook.job_name)
assert result == expected_job_name

@unittest.skipIf(mock_glue is None, "mock_glue package not present")
@mock.patch.object(GlueJobHook, "get_iam_execution_role")
def test_get_or_create_glue_job_create_new_job(self, mock_get_iam_execution_role):
    """
    Calls 'get_or_create_glue_job' with no existing job.
    Should create a new job and return its name.

    The IAM role lookup is patched out so that only the Glue API is
    exercised; Glue itself is served by moto's in-memory mock.
    """
    mock_get_iam_execution_role.return_value = {"Role": {"RoleName": "my_test_role", "Arn": "test_role"}}
    expected_job_name = "aws_test_glue_job"

    # mock_glue is entered here rather than applied as a decorator: when moto
    # is not installed the name is None, and decorating with None would raise
    # TypeError while the class body is evaluated, before skipIf could skip
    # the test. This body only runs when the skip condition is false.
    with mock_glue():
        hook = GlueJobHook(
            job_name=expected_job_name,
            desc="This is test case job from Airflow",
            iam_role_name="my_test_role",
            script_location="s3://bucket",
            s3_bucket="bucket",
            region_name=self.some_aws_region,
        )

        result = hook.get_or_create_glue_job()

    assert result == expected_job_name

@mock.patch.object(GlueJobHook, "get_iam_execution_role")
@mock.patch.object(GlueJobHook, "get_conn")
Expand Down

0 comments on commit 41898d8

Please sign in to comment.