From 2b9c20f0a58564e12c5a3ab4290cc1572acb6a1d Mon Sep 17 00:00:00 2001 From: Ping Zhang Date: Thu, 12 May 2022 10:51:16 -0700 Subject: [PATCH] Fallback to parse dag_file when no dag in the db --- airflow/cli/commands/task_command.py | 6 ++++- tests/cli/commands/test_task_command.py | 32 +++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index 78e8fc20f6e49..e054e4575cc71 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -358,7 +358,11 @@ def task_run(args, dag=None): dag = get_dag_by_pickle(args.pickle) elif not dag: if args.local: - dag = get_dag_by_deserialization(args.dag_id) + try: + dag = get_dag_by_deserialization(args.dag_id) + except AirflowException: + print(f'DAG {args.dag_id} does not exist in the database, trying to parse the dag_file') + dag = get_dag(args.subdir, args.dag_id) else: dag = get_dag(args.subdir, args.dag_id) else: diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py index fac04fd50802d..4822816b48cc1 100644 --- a/tests/cli/commands/test_task_command.py +++ b/tests/cli/commands/test_task_command.py @@ -152,6 +152,38 @@ def test_run_get_serialized_dag(self, mock_local_job, mock_get_dag_by_deserializ ) mock_get_dag_by_deserialization.assert_called_once_with(self.dag_id) + @mock.patch("airflow.cli.commands.task_command.get_dag_by_deserialization") + @mock.patch("airflow.cli.commands.task_command.LocalTaskJob") + def test_run_get_serialized_dag_fallback(self, mock_local_job, mock_get_dag_by_deserialization): + """ + Fallback to parse dag_file when serialized dag does not exist in the db + """ + task_id = self.dag.task_ids[0] + args = [ + 'tasks', + 'run', + '--ignore-all-dependencies', + '--local', + self.dag_id, + task_id, + self.run_id, + ] + mock_get_dag_by_deserialization.side_effect = mock.Mock(side_effect=AirflowException('Not found')) + + task_command.task_run(self.parser.parse_args(args)) + mock_local_job.assert_called_once_with( + task_instance=mock.ANY, + mark_success=False, + ignore_all_deps=True, + ignore_depends_on_past=False, + ignore_task_deps=False, + ignore_ti_state=False, + pickle_id=None, + pool=None, + external_executor_id=None, + ) + mock_get_dag_by_deserialization.assert_called_once_with(self.dag_id) + @mock.patch("airflow.cli.commands.task_command.LocalTaskJob") def test_run_with_existing_dag_run_id(self, mock_local_job): """