From d213759b54477766a2c7cf54398cb009ac39f82c Mon Sep 17 00:00:00 2001 From: Chris Redekop Date: Thu, 12 May 2022 12:34:26 -0600 Subject: [PATCH 1/8] Add stalled celery task handling for all tasks, rather than just adopted tasks --- airflow/config_templates/config.yml | 15 ++- airflow/config_templates/default_airflow.cfg | 11 +- airflow/executors/celery_executor.py | 106 ++++++++++------ tests/executors/test_celery_executor.py | 120 ++++++++++++------- 4 files changed, 171 insertions(+), 81 deletions(-) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 89479b9b62b15f..7cff1bc4032dcd 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -1768,12 +1768,23 @@ default: "True" - name: task_adoption_timeout description: | - Time in seconds after which Adopted tasks are cleared by CeleryExecutor. This is helpful to clear - stalled tasks. + Time in seconds after which adopted tasks which are queued in celery are assumed to be stalled, + and are automatically cleared. This setting does the same thing as ``stalled_task_timeout`` but + applies specifically to adopted tasks only. When set to 0, the ``stalled_task_timeout`` setting + also applies to adopted tasks. version_added: 2.0.0 type: integer example: ~ default: "600" + - name: stalled_task_timeout + description: | + Time in seconds after which tasks queued in celery are assumed to be stalled, and are automatically + cleared. Adopted tasks will instead use the ``task_adoption_timeout`` setting if specified. + When set to 0, automatic clearing of stalled tasks is disabled. + version_added: 2.3.1 + type: integer + example: ~ + default: "0" - name: task_publish_max_retries description: | The Maximum number of retries for publishing task messages to the broker when failing diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index a457ea6fa96a9e..3ca6528d8df604 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -888,10 +888,17 @@ operation_timeout = 1.0 # or run in HA mode, it can adopt the orphan tasks launched by previous SchedulerJob. task_track_started = True -# Time in seconds after which Adopted tasks are cleared by CeleryExecutor. This is helpful to clear -# stalled tasks. +# Time in seconds after which adopted tasks which are queued in celery are assumed to be stalled, +# and are automatically cleared. This setting does the same thing as ``stalled_task_timeout`` but +# applies specifically to adopted tasks only. When set to 0, the ``stalled_task_timeout`` setting +# also applies to adopted tasks. task_adoption_timeout = 600 +# Time in seconds after which tasks queued in celery are assumed to be stalled, and are automatically +# cleared. Adopted tasks will instead use the ``task_adoption_timeout`` setting if specified. +# When set to 0, automatic clearing of stalled tasks is disabled. +stalled_task_timeout = 0 + # The Maximum number of retries for publishing task messages to the broker when failing # due to ``AirflowTaskTimeout`` error before giving up and marking Task as failed. 
task_publish_max_retries = 3 diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py index 752244d5528c6b..e7f1e55a716cb4 100644 --- a/airflow/executors/celery_executor.py +++ b/airflow/executors/celery_executor.py @@ -29,7 +29,7 @@ import subprocess import time import traceback -from collections import Counter, OrderedDict +from collections import Counter from concurrent.futures import ProcessPoolExecutor from multiprocessing import cpu_count from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Set, Tuple, Union @@ -40,6 +40,7 @@ from celery.result import AsyncResult from celery.signals import import_modules as celery_import_modules from setproctitle import setproctitle +from sqlalchemy.orm.session import Session import airflow.settings as settings from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG @@ -50,6 +51,7 @@ from airflow.stats import Stats from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import State from airflow.utils.timeout import timeout from airflow.utils.timezone import utcnow @@ -230,11 +232,13 @@ def __init__(self): self._sync_parallelism = max(1, cpu_count() - 1) self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism) self.tasks = {} - # Mapping of tasks we've adopted, ordered by the earliest date they timeout - self.adopted_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = OrderedDict() + self.stalled_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = {} self.task_adoption_timeout = datetime.timedelta( seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600) ) + self.stalled_task_timeout = datetime.timedelta( + seconds=conf.getint('celery', 'stalled_task_timeout', fallback=0) + ) self.task_publish_retries: Counter[TaskInstanceKey] = Counter() self.task_publish_max_retries = conf.getint('celery', 'task_publish_max_retries', fallback=3) @@ -285,6 +289,7 @@ def _process_tasks(self, task_tuples: List[TaskTuple]) -> None: result.backend = cached_celery_backend self.running.add(key) self.tasks[key] = result + self._set_stalled_task_timeout(key, self.stalled_task_timeout) # Store the Celery task_id in the event buffer. This will get "overwritten" if the task # has another event, but that is fine, because the only other events are success/failed at @@ -315,25 +320,23 @@ def sync(self) -> None: self.log.debug("No task to query celery, skipping sync") return self.update_all_task_states() + self._check_for_stalled_tasks() - if self.adopted_task_timeouts: - self._check_for_stalled_adopted_tasks() - - def _check_for_stalled_adopted_tasks(self): + @provide_session + def _check_for_stalled_tasks(self, session: Session = NEW_SESSION): """ - See if any of the tasks we adopted from another Executor run have not - progressed after the configured timeout. - - If they haven't, they likely never made it to Celery, and we should - just resend them. We do that by clearing the state and letting the - normal scheduler loop deal with that + Check to see if any of our tasks have not progressed in the expected + time. This can happen for few different reasons, usually related to + race conditions while shutting down schedulers and celery workers. + + It is, of course, always possible that these tasks are not actually + stalled - they could just be waiting in a long celery queue. 
+ Unfortunately there's no way for us to know for sure, so we'll just + reschedule them and let the normal scheduler loop requeue them. """ now = utcnow() - - sorted_adopted_task_timeouts = sorted(self.adopted_task_timeouts.items(), key=lambda k: k[1]) - timedout_keys = [] - for key, stalled_after in sorted_adopted_task_timeouts: + for key, stalled_after in self.stalled_task_timeouts.items(): if stalled_after > now: # Since items are stored sorted, if we get to a stalled_after # in the future then we can stop @@ -343,20 +346,46 @@ def _check_for_stalled_adopted_tasks(self): # already finished, then it will be removed from this list -- so # the only time it's still in this list is when it a) never made it # to celery in the first place (i.e. race condition somewhere in - # the dying executor) or b) a really long celery queue and it just + # the dying executor), b) celery lost the task before execution + # started, or c) a really long celery queue and it just # hasn't started yet -- better cancel it and let the scheduler # re-queue rather than have this task risk stalling for ever timedout_keys.append(key) - if timedout_keys: - self.log.error( - "Adopted tasks were still pending after %s, assuming they never made it to celery and " - "clearing:\n\t%s", - self.task_adoption_timeout, - "\n\t".join(repr(x) for x in timedout_keys), - ) - for key in timedout_keys: - self.change_state(key, State.FAILED) + if not timedout_keys: + return + + self.log.error( + "Tasks were still pending after configured timeout (adopted: %s, all: %s), " + "assuming they never made it to celery and clearing:\n\t%s", + self.task_adoption_timeout, + self.stalled_task_timeout, + "\n\t".join(repr(x) for x in timedout_keys), + ) + + filter_for_tis = TaskInstance.filter_for_tis(timedout_keys) + session.query(TaskInstance).filter( + filter_for_tis, TaskInstance.state == State.QUEUED, TaskInstance.queued_by_job_id == self.job_id + ).update( + { + TaskInstance.state: State.SCHEDULED, + TaskInstance.queued_dttm: None, + TaskInstance.queued_by_job_id: None, + TaskInstance.external_executor_id: None, + }, + synchronize_session=False, + ) + session.commit() + + for key in timedout_keys: + self.stalled_task_timeouts.pop(key, None) + self.running.discard(key) + celery_async_result = self.tasks.pop(key, None) + if celery_async_result: + try: + app.control.revoke(celery_async_result.task_id) + except Exception as ex: + self.log.error("Error revoking task instance %s from celery: %s", key, ex) def debug_dump(self) -> None: """Called in response to SIGUSR2 by the scheduler""" @@ -365,9 +394,9 @@ def debug_dump(self) -> None: "executor.tasks (%d)\n\t%s", len(self.tasks), "\n\t".join(map(repr, self.tasks.items())) ) self.log.info( - "executor.adopted_task_timeouts (%d)\n\t%s", - len(self.adopted_task_timeouts), - "\n\t".join(map(repr, self.adopted_task_timeouts.items())), + "executor.stalled_task_timeouts (%d)\n\t%s", + len(self.stalled_task_timeouts), + "\n\t".join(map(repr, self.stalled_task_timeouts.items())), ) def update_all_task_states(self) -> None: @@ -384,7 +413,7 @@ def update_all_task_states(self) -> None: def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None: super().change_state(key, state, info) self.tasks.pop(key, None) - self.adopted_task_timeouts.pop(key, None) + self.stalled_task_timeouts.pop(key, None) def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None: """Updates state of a single task.""" @@ -394,8 +423,8 @@ def update_task_state(self, key: TaskInstanceKey, state: 
str, info: Any) -> None elif state in (celery_states.FAILURE, celery_states.REVOKED): self.fail(key, info) elif state == celery_states.STARTED: - # It's now actually running, so know it made it to celery okay! - self.adopted_task_timeouts.pop(key, None) + # It's now actually running, so we know it made it to celery okay! + self.stalled_task_timeouts.pop(key, None) elif state == celery_states.PENDING: pass else: @@ -455,7 +484,7 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance # Set the correct elements of the state dicts, then update this # like we just queried it. - self.adopted_task_timeouts[ti.key] = ti.queued_dttm + self.task_adoption_timeout + self._set_stalled_task_timeout(ti.key, self.task_adoption_timeout or self.stalled_task_timeout) self.tasks[ti.key] = result self.running.add(ti.key) self.update_task_state(ti.key, state, info) @@ -469,6 +498,15 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance return not_adopted_tis + def _set_stalled_task_timeout(self, key: TaskInstanceKey, timeout: datetime.timedelta) -> None: + if timeout: + self.stalled_task_timeouts[key] = utcnow() + timeout + self.stalled_task_timeouts = dict( + sorted(self.stalled_task_timeouts.items(), key=lambda item: item[1]) + ) + else: + self.stalled_task_timeouts.pop(key, None) + def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str, ExceptionWithTraceback], Any]: """ diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py index d5bd3dbeae7a3e..a44461027cf91c 100644 --- a/tests/executors/test_celery_executor.py +++ b/tests/executors/test_celery_executor.py @@ -33,6 +33,7 @@ from celery.backends.database import DatabaseBackend from celery.contrib.testing.worker import start_worker from celery.result import AsyncResult +from freezegun import freeze_time from kombu.asynchronous import set_event_loop from parameterized import parameterized @@ -311,9 +312,9 @@ def test_try_adopt_task_instances_none(self): assert executor.try_adopt_task_instances(tis) == tis @pytest.mark.backend("mysql", "postgres") + @freeze_time("2020-01-01") def test_try_adopt_task_instances(self): start_date = timezone.utcnow() - timedelta(days=2) - queued_dttm = timezone.utcnow() - timedelta(minutes=1) try_number = 1 @@ -323,17 +324,15 @@ def test_try_adopt_task_instances(self): ti1 = TaskInstance(task=task_1, run_id=None) ti1.external_executor_id = '231' - ti1.queued_dttm = queued_dttm ti1.state = State.QUEUED ti2 = TaskInstance(task=task_2, run_id=None) ti2.external_executor_id = '232' - ti2.queued_dttm = queued_dttm ti2.state = State.QUEUED tis = [ti1, ti2] executor = celery_executor.CeleryExecutor() assert executor.running == set() - assert executor.adopted_task_timeouts == {} + assert executor.stalled_task_timeouts == {} assert executor.tasks == {} not_adopted_tis = executor.try_adopt_task_instances(tis) @@ -341,67 +340,102 @@ def test_try_adopt_task_instances(self): key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, None, try_number) key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, None, try_number) assert executor.running == {key_1, key_2} - assert dict(executor.adopted_task_timeouts) == { - key_1: queued_dttm + executor.task_adoption_timeout, - key_2: queued_dttm + executor.task_adoption_timeout, + assert dict(executor.stalled_task_timeouts) == { + key_1: timezone.utcnow() + executor.task_adoption_timeout, + key_2: timezone.utcnow() + executor.task_adoption_timeout, } assert executor.tasks == {key_1: 
AsyncResult("231"), key_2: AsyncResult("232")} assert not_adopted_tis == [] - @pytest.mark.backend("mysql", "postgres") - def test_check_for_stalled_adopted_tasks(self): - start_date = timezone.utcnow() - timedelta(days=2) - queued_dttm = timezone.utcnow() - timedelta(minutes=30) - - try_number = 1 + @pytest.fixture + def mock_celery_revoke(self): + with _prepare_app() as app: + app.control.revoke = mock.MagicMock() + yield app.control.revoke - with DAG("test_check_for_stalled_adopted_tasks") as dag: - task_1 = BaseOperator(task_id="task_1", start_date=start_date) - task_2 = BaseOperator(task_id="task_2", start_date=start_date) + @pytest.mark.backend("mysql", "postgres") + def test_check_for_stalled_tasks(self, create_dummy_dag, dag_maker, session, mock_celery_revoke): + create_dummy_dag(dag_id="test_clear_stalled", task_id="task1", with_dagrun_type=None) + dag_run = dag_maker.create_dagrun() - key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, "runid", try_number) - key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, "runid", try_number) + ti = dag_run.task_instances[0] + ti.state = State.QUEUED + ti.queued_dttm = timezone.utcnow() + ti.queued_by_job_id = 1 + ti.external_executor_id = '231' + session.flush() executor = celery_executor.CeleryExecutor() - executor.adopted_task_timeouts = { - key_1: queued_dttm + executor.task_adoption_timeout, - key_2: queued_dttm + executor.task_adoption_timeout, + executor.job_id = 1 + executor.stalled_task_timeouts = { + ti.key: timezone.utcnow() - timedelta(days=1), } - executor.running = {key_1, key_2} - executor.tasks = {key_1: AsyncResult("231"), key_2: AsyncResult("232")} + executor.running = {ti.key} + executor.tasks = {ti.key: AsyncResult("231")} executor.sync() - assert executor.event_buffer == {key_1: (State.FAILED, None), key_2: (State.FAILED, None)} + assert executor.event_buffer == {} assert executor.tasks == {} assert executor.running == set() - assert executor.adopted_task_timeouts == {} + assert executor.stalled_task_timeouts == {} + assert mock_celery_revoke.called_with("231") + + ti.refresh_from_db() + assert ti.state == State.SCHEDULED + assert ti.queued_by_job_id is None + assert ti.queued_dttm is None + assert ti.external_executor_id is None @pytest.mark.backend("mysql", "postgres") - def test_check_for_stalled_adopted_tasks_goes_in_ordered_fashion(self): + @freeze_time("2020-01-01") + def test_check_for_stalled_tasks_goes_in_ordered_fashion(self): start_date = timezone.utcnow() - timedelta(days=2) - queued_dttm = timezone.utcnow() - timedelta(minutes=30) - queued_dttm_2 = timezone.utcnow() - timedelta(minutes=4) - try_number = 1 - - with DAG("test_check_for_stalled_adopted_tasks") as dag: + with DAG("test_check_for_stalled_tasks_are_ordered"): task_1 = BaseOperator(task_id="task_1", start_date=start_date) task_2 = BaseOperator(task_id="task_2", start_date=start_date) - key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, "runid", try_number) - key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, "runid", try_number) + ti1 = TaskInstance(task=task_1, run_id=None) + ti1.external_executor_id = '231' + ti1.state = State.QUEUED + ti2 = TaskInstance(task=task_2, run_id=None) + ti2.external_executor_id = '232' + ti2.state = State.QUEUED executor = celery_executor.CeleryExecutor() - executor.adopted_task_timeouts = { - key_2: queued_dttm_2 + executor.task_adoption_timeout, - key_1: queued_dttm + executor.task_adoption_timeout, + executor.stalled_task_timeout = timedelta(seconds=30) + executor.queued_tasks[ti2.key] = (None, None, None, None) + 
executor.try_adopt_task_instances([ti1]) + with mock.patch('airflow.executors.celery_executor.send_task_to_executor') as mock_send_task: + mock_send_task.return_value = (ti2.key, None, mock.MagicMock()) + executor._process_tasks([(ti2.key, None, None, mock.MagicMock())]) + assert executor.stalled_task_timeouts == { + ti2.key: timezone.utcnow() + timedelta(seconds=30), + ti1.key: timezone.utcnow() + timedelta(seconds=600), } - executor.running = {key_1, key_2} - executor.tasks = {key_1: AsyncResult("231"), key_2: AsyncResult("232")} - executor.sync() - assert executor.event_buffer == {key_1: (State.FAILED, None)} - assert executor.tasks == {key_2: AsyncResult('232')} - assert executor.running == {key_2} - assert executor.adopted_task_timeouts == {key_2: queued_dttm_2 + executor.task_adoption_timeout} + + @pytest.mark.backend("mysql", "postgres") + def test_no_stalled_task_timeouts_when_configured(self): + start_date = timezone.utcnow() - timedelta(days=2) + + with DAG("test_check_for_stalled_tasks_are_ordered"): + task_1 = BaseOperator(task_id="task_1", start_date=start_date) + task_2 = BaseOperator(task_id="task_2", start_date=start_date) + + ti1 = TaskInstance(task=task_1, run_id=None) + ti1.external_executor_id = '231' + ti1.state = State.QUEUED + ti2 = TaskInstance(task=task_2, run_id=None) + ti2.external_executor_id = '232' + ti2.state = State.QUEUED + + executor = celery_executor.CeleryExecutor() + executor.task_adoption_timeout = timedelta(0) + executor.queued_tasks[ti2.key] = (None, None, None, None) + executor.try_adopt_task_instances([ti1]) + with mock.patch('airflow.executors.celery_executor.send_task_to_executor') as mock_send_task: + mock_send_task.return_value = (ti2.key, None, mock.MagicMock()) + executor._process_tasks([(ti2.key, None, None, mock.MagicMock())]) + assert executor.stalled_task_timeouts == {} def test_operation_timeout_config(): From 3ce6c00ed4c1505fa30ed37c1f03c93dde13b978 Mon Sep 17 00:00:00 2001 From: Chris Redekop Date: Fri, 13 May 2022 12:26:55 -0600 Subject: [PATCH 2/8] Add proper error handling --- airflow/executors/celery_executor.py | 37 +++++++++++++++++----------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py index e7f1e55a716cb4..6bc08b05cac558 100644 --- a/airflow/executors/celery_executor.py +++ b/airflow/executors/celery_executor.py @@ -363,19 +363,26 @@ def _check_for_stalled_tasks(self, session: Session = NEW_SESSION): "\n\t".join(repr(x) for x in timedout_keys), ) - filter_for_tis = TaskInstance.filter_for_tis(timedout_keys) - session.query(TaskInstance).filter( - filter_for_tis, TaskInstance.state == State.QUEUED, TaskInstance.queued_by_job_id == self.job_id - ).update( - { - TaskInstance.state: State.SCHEDULED, - TaskInstance.queued_dttm: None, - TaskInstance.queued_by_job_id: None, - TaskInstance.external_executor_id: None, - }, - synchronize_session=False, - ) - session.commit() + try: + filter_for_tis = TaskInstance.filter_for_tis(timedout_keys) + session.query(TaskInstance).filter( + filter_for_tis, + TaskInstance.state == State.QUEUED, + TaskInstance.queued_by_job_id == self.job_id, + ).update( + { + TaskInstance.state: State.SCHEDULED, + TaskInstance.queued_dttm: None, + TaskInstance.queued_by_job_id: None, + TaskInstance.external_executor_id: None, + }, + synchronize_session=False, + ) + session.commit() + except Exception: + self.log.exception("Error clearing stalled tasks") + session.rollback() + return for key in timedout_keys: 
self.stalled_task_timeouts.pop(key, None) @@ -384,8 +391,8 @@ def _check_for_stalled_tasks(self, session: Session = NEW_SESSION): if celery_async_result: try: app.control.revoke(celery_async_result.task_id) - except Exception as ex: - self.log.error("Error revoking task instance %s from celery: %s", key, ex) + except Exception: + self.log.exception("Error revoking task instance %s from celery", key) def debug_dump(self) -> None: """Called in response to SIGUSR2 by the scheduler""" From 01e778dc92f2c06766bc7bf59e8ea48062ff01f3 Mon Sep 17 00:00:00 2001 From: Chris Redekop <32752154+repl-chris@users.noreply.github.com> Date: Tue, 17 May 2022 10:05:11 -0600 Subject: [PATCH 3/8] Tweak documentation accuracy Co-authored-by: Ash Berlin-Taylor --- airflow/config_templates/default_airflow.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 3ca6528d8df604..4838ea3857b135 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -895,7 +895,7 @@ task_track_started = True task_adoption_timeout = 600 # Time in seconds after which tasks queued in celery are assumed to be stalled, and are automatically -# cleared. Adopted tasks will instead use the ``task_adoption_timeout`` setting if specified. +# rescheduled. Adopted tasks will instead use the ``task_adoption_timeout`` setting if specified. # When set to 0, automatic clearing of stalled tasks is disabled. stalled_task_timeout = 0 From a22436ca92fa01755a6fd11f0fdb0dc4040cc0e1 Mon Sep 17 00:00:00 2001 From: Chris Redekop <32752154+repl-chris@users.noreply.github.com> Date: Tue, 17 May 2022 10:11:44 -0600 Subject: [PATCH 4/8] Tweak error message accuracy Co-authored-by: Ash Berlin-Taylor --- airflow/executors/celery_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py index 6bc08b05cac558..076095155a553e 100644 --- a/airflow/executors/celery_executor.py +++ b/airflow/executors/celery_executor.py @@ -357,7 +357,7 @@ def _check_for_stalled_tasks(self, session: Session = NEW_SESSION): self.log.error( "Tasks were still pending after configured timeout (adopted: %s, all: %s), " - "assuming they never made it to celery and clearing:\n\t%s", + "assuming they never made it to celery and sending back to the scheduler:\n\t%s", self.task_adoption_timeout, self.stalled_task_timeout, "\n\t".join(repr(x) for x in timedout_keys), From 8599751a4812cb0701774226096b1be42dc002ed Mon Sep 17 00:00:00 2001 From: Chris Redekop Date: Tue, 17 May 2022 13:06:04 -0600 Subject: [PATCH 5/8] split tracking and processing of adopted and 'native' stalled tasks for clarity --- airflow/executors/celery_executor.py | 109 +++++++++++++++--------- tests/executors/test_celery_executor.py | 43 +++++++++- 2 files changed, 110 insertions(+), 42 deletions(-) diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py index 076095155a553e..def2d4336c8355 100644 --- a/airflow/executors/celery_executor.py +++ b/airflow/executors/celery_executor.py @@ -31,6 +31,7 @@ import traceback from collections import Counter from concurrent.futures import ProcessPoolExecutor +from enum import Enum from multiprocessing import cpu_count from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Set, Tuple, Union @@ -209,6 +210,11 @@ def on_celery_import_modules(*args, **kwargs): pass +class 
_CeleryPendingTaskTimeoutType(Enum): + ADOPTED = 1 + STALLED = 2 + + class CeleryExecutor(BaseExecutor): """ CeleryExecutor is recommended for production use of Airflow. It allows @@ -233,12 +239,14 @@ def __init__(self): self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism) self.tasks = {} self.stalled_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = {} - self.task_adoption_timeout = datetime.timedelta( - seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600) - ) self.stalled_task_timeout = datetime.timedelta( seconds=conf.getint('celery', 'stalled_task_timeout', fallback=0) ) + self.adopted_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = {} + self.task_adoption_timeout = ( + datetime.timedelta(seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600)) + or self.stalled_task_timeout + ) self.task_publish_retries: Counter[TaskInstanceKey] = Counter() self.task_publish_max_retries = conf.getint('celery', 'task_publish_max_retries', fallback=3) @@ -289,7 +297,7 @@ def _process_tasks(self, task_tuples: List[TaskTuple]) -> None: result.backend = cached_celery_backend self.running.add(key) self.tasks[key] = result - self._set_stalled_task_timeout(key, self.stalled_task_timeout) + self._set_celery_pending_task_timeout(key, _CeleryPendingTaskTimeoutType.STALLED) # Store the Celery task_id in the event buffer. This will get "overwritten" if the task # has another event, but that is fine, because the only other events are success/failed at @@ -320,14 +328,38 @@ def sync(self) -> None: self.log.debug("No task to query celery, skipping sync") return self.update_all_task_states() + self._check_for_timedout_adopted_tasks() self._check_for_stalled_tasks() - @provide_session - def _check_for_stalled_tasks(self, session: Session = NEW_SESSION): + def _check_for_timedout_adopted_tasks(self) -> None: + timedout_keys = self._get_timedout_ti_keys(self.adopted_task_timeouts) + if timedout_keys: + self.log.error( + "Adopted tasks were still pending after %s, assuming they never made it to celery " + "and sending back to the scheduler:\n\t%s", + self.task_adoption_timeout, + "\n\t".join(repr(x) for x in timedout_keys), + ) + self._send_stalled_tis_back_to_scheduler(timedout_keys) + + def _check_for_stalled_tasks(self) -> None: + timedout_keys = self._get_timedout_ti_keys(self.stalled_task_timeouts) + if timedout_keys: + self.log.error( + "Tasks were still pending after %s, assuming they never made it to celery " + "and sending back to the scheduler:\n\t%s", + self.stalled_task_timeout, + "\n\t".join(repr(x) for x in timedout_keys), + ) + self._send_stalled_tis_back_to_scheduler(timedout_keys) + + def _get_timedout_ti_keys( + self, task_timeouts: Dict[TaskInstanceKey, datetime.datetime] + ) -> List[TaskInstanceKey]: """ - Check to see if any of our tasks have not progressed in the expected - time. This can happen for few different reasons, usually related to - race conditions while shutting down schedulers and celery workers. + These timeouts exist to check to see if any of our tasks have not progressed + in the expected time. This can happen for few different reasons, usually related + to race conditions while shutting down schedulers and celery workers. It is, of course, always possible that these tasks are not actually stalled - they could just be waiting in a long celery queue. 
@@ -336,7 +368,7 @@ def _check_for_stalled_tasks(self, session: Session = NEW_SESSION): """ now = utcnow() timedout_keys = [] - for key, stalled_after in self.stalled_task_timeouts.items(): + for key, stalled_after in task_timeouts.items(): if stalled_after > now: # Since items are stored sorted, if we get to a stalled_after # in the future then we can stop @@ -351,22 +383,15 @@ def _check_for_stalled_tasks(self, session: Session = NEW_SESSION): # hasn't started yet -- better cancel it and let the scheduler # re-queue rather than have this task risk stalling for ever timedout_keys.append(key) + return timedout_keys - if not timedout_keys: - return - - self.log.error( - "Tasks were still pending after configured timeout (adopted: %s, all: %s), " - "assuming they never made it to celery and sending back to the scheduler:\n\t%s", - self.task_adoption_timeout, - self.stalled_task_timeout, - "\n\t".join(repr(x) for x in timedout_keys), - ) - + @provide_session + def _send_stalled_tis_back_to_scheduler( + self, keys: List[TaskInstanceKey], session: Session = NEW_SESSION + ) -> None: try: - filter_for_tis = TaskInstance.filter_for_tis(timedout_keys) session.query(TaskInstance).filter( - filter_for_tis, + TaskInstance.filter_for_tis(keys), TaskInstance.state == State.QUEUED, TaskInstance.queued_by_job_id == self.job_id, ).update( @@ -380,19 +405,19 @@ def _check_for_stalled_tasks(self, session: Session = NEW_SESSION): ) session.commit() except Exception: - self.log.exception("Error clearing stalled tasks") + self.log.exception("Error sending tasks back to scheduler") session.rollback() return - for key in timedout_keys: - self.stalled_task_timeouts.pop(key, None) + for key in keys: + self._set_celery_pending_task_timeout(key, None) self.running.discard(key) celery_async_result = self.tasks.pop(key, None) if celery_async_result: try: app.control.revoke(celery_async_result.task_id) - except Exception: - self.log.exception("Error revoking task instance %s from celery", key) + except Exception as ex: + self.log.error("Error revoking task instance %s from celery: %s", key, ex) def debug_dump(self) -> None: """Called in response to SIGUSR2 by the scheduler""" @@ -400,6 +425,11 @@ def debug_dump(self) -> None: self.log.info( "executor.tasks (%d)\n\t%s", len(self.tasks), "\n\t".join(map(repr, self.tasks.items())) ) + self.log.info( + "executor.adopted_task_timeouts (%d)\n\t%s", + len(self.adopted_task_timeouts), + "\n\t".join(map(repr, self.adopted_task_timeouts.items())), + ) self.log.info( "executor.stalled_task_timeouts (%d)\n\t%s", len(self.stalled_task_timeouts), @@ -420,7 +450,7 @@ def update_all_task_states(self) -> None: def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None: super().change_state(key, state, info) self.tasks.pop(key, None) - self.stalled_task_timeouts.pop(key, None) + self._set_celery_pending_task_timeout(key, None) def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None: """Updates state of a single task.""" @@ -431,7 +461,7 @@ def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None self.fail(key, info) elif state == celery_states.STARTED: # It's now actually running, so we know it made it to celery okay! 
- self.stalled_task_timeouts.pop(key, None) + self._set_celery_pending_task_timeout(key, None) elif state == celery_states.PENDING: pass else: @@ -491,7 +521,7 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance # Set the correct elements of the state dicts, then update this # like we just queried it. - self._set_stalled_task_timeout(ti.key, self.task_adoption_timeout or self.stalled_task_timeout) + self._set_celery_pending_task_timeout(ti.key, _CeleryPendingTaskTimeoutType.ADOPTED) self.tasks[ti.key] = result self.running.add(ti.key) self.update_task_state(ti.key, state, info) @@ -505,14 +535,15 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance return not_adopted_tis - def _set_stalled_task_timeout(self, key: TaskInstanceKey, timeout: datetime.timedelta) -> None: - if timeout: - self.stalled_task_timeouts[key] = utcnow() + timeout - self.stalled_task_timeouts = dict( - sorted(self.stalled_task_timeouts.items(), key=lambda item: item[1]) - ) - else: - self.stalled_task_timeouts.pop(key, None) + def _set_celery_pending_task_timeout( + self, key: TaskInstanceKey, timeout_type: Optional[_CeleryPendingTaskTimeoutType] + ) -> None: + self.adopted_task_timeouts.pop(key, None) + self.stalled_task_timeouts.pop(key, None) + if timeout_type == _CeleryPendingTaskTimeoutType.ADOPTED and self.task_adoption_timeout: + self.adopted_task_timeouts[key] = utcnow() + self.task_adoption_timeout + elif timeout_type == _CeleryPendingTaskTimeoutType.STALLED and self.stalled_task_timeout: + self.stalled_task_timeouts[key] = utcnow() + self.stalled_task_timeout def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str, ExceptionWithTraceback], Any]: diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py index a44461027cf91c..9d9485e3ce3dad 100644 --- a/tests/executors/test_celery_executor.py +++ b/tests/executors/test_celery_executor.py @@ -332,6 +332,7 @@ def test_try_adopt_task_instances(self): tis = [ti1, ti2] executor = celery_executor.CeleryExecutor() assert executor.running == set() + assert executor.adopted_task_timeouts == {} assert executor.stalled_task_timeouts == {} assert executor.tasks == {} @@ -340,10 +341,11 @@ def test_try_adopt_task_instances(self): key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, None, try_number) key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, None, try_number) assert executor.running == {key_1, key_2} - assert dict(executor.stalled_task_timeouts) == { + assert executor.adopted_task_timeouts == { key_1: timezone.utcnow() + executor.task_adoption_timeout, key_2: timezone.utcnow() + executor.task_adoption_timeout, } + assert executor.stalled_task_timeouts == {} assert executor.tasks == {key_1: AsyncResult("231"), key_2: AsyncResult("232")} assert not_adopted_tis == [] @@ -353,6 +355,38 @@ def mock_celery_revoke(self): app.control.revoke = mock.MagicMock() yield app.control.revoke + @pytest.mark.backend("mysql", "postgres") + def test_check_for_timedout_adopted_tasks(self, create_dummy_dag, dag_maker, session, mock_celery_revoke): + create_dummy_dag(dag_id="test_clear_stalled", task_id="task1", with_dagrun_type=None) + dag_run = dag_maker.create_dagrun() + + ti = dag_run.task_instances[0] + ti.state = State.QUEUED + ti.queued_dttm = timezone.utcnow() + ti.queued_by_job_id = 1 + ti.external_executor_id = '231' + session.flush() + + executor = celery_executor.CeleryExecutor() + executor.job_id = 1 + executor.adopted_task_timeouts = { + ti.key: 
timezone.utcnow() - timedelta(days=1), + } + executor.running = {ti.key} + executor.tasks = {ti.key: AsyncResult("231")} + executor.sync() + assert executor.event_buffer == {} + assert executor.tasks == {} + assert executor.running == set() + assert executor.adopted_task_timeouts == {} + assert mock_celery_revoke.called_with("231") + + ti.refresh_from_db() + assert ti.state == State.SCHEDULED + assert ti.queued_by_job_id is None + assert ti.queued_dttm is None + assert ti.external_executor_id is None + @pytest.mark.backend("mysql", "postgres") def test_check_for_stalled_tasks(self, create_dummy_dag, dag_maker, session, mock_celery_revoke): create_dummy_dag(dag_id="test_clear_stalled", task_id="task1", with_dagrun_type=None) @@ -387,7 +421,7 @@ def test_check_for_stalled_tasks(self, create_dummy_dag, dag_maker, session, moc @pytest.mark.backend("mysql", "postgres") @freeze_time("2020-01-01") - def test_check_for_stalled_tasks_goes_in_ordered_fashion(self): + def test_pending_tasks_timeout_with_appropriate_config_setting(self): start_date = timezone.utcnow() - timedelta(days=2) with DAG("test_check_for_stalled_tasks_are_ordered"): @@ -410,11 +444,13 @@ def test_check_for_stalled_tasks_goes_in_ordered_fashion(self): executor._process_tasks([(ti2.key, None, None, mock.MagicMock())]) assert executor.stalled_task_timeouts == { ti2.key: timezone.utcnow() + timedelta(seconds=30), + } + assert executor.adopted_task_timeouts == { ti1.key: timezone.utcnow() + timedelta(seconds=600), } @pytest.mark.backend("mysql", "postgres") - def test_no_stalled_task_timeouts_when_configured(self): + def test_no_pending_task_timeouts_when_configured(self): start_date = timezone.utcnow() - timedelta(days=2) with DAG("test_check_for_stalled_tasks_are_ordered"): @@ -435,6 +471,7 @@ def test_no_stalled_task_timeouts_when_configured(self): with mock.patch('airflow.executors.celery_executor.send_task_to_executor') as mock_send_task: mock_send_task.return_value = (ti2.key, None, mock.MagicMock()) executor._process_tasks([(ti2.key, None, None, mock.MagicMock())]) + assert executor.adopted_task_timeouts == {} assert executor.stalled_task_timeouts == {} From 7eeeb92d0cf5d5e38ec2d0c470661cdee1ba10b5 Mon Sep 17 00:00:00 2001 From: Chris Redekop Date: Tue, 17 May 2022 17:34:09 -0600 Subject: [PATCH 6/8] Fix doc inconsistency --- airflow/config_templates/config.yml | 4 ++-- airflow/config_templates/default_airflow.cfg | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 7cff1bc4032dcd..b5ca6a7d80c63a 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -1769,7 +1769,7 @@ - name: task_adoption_timeout description: | Time in seconds after which adopted tasks which are queued in celery are assumed to be stalled, - and are automatically cleared. This setting does the same thing as ``stalled_task_timeout`` but + and are automatically rescheduled. This setting does the same thing as ``stalled_task_timeout`` but applies specifically to adopted tasks only. When set to 0, the ``stalled_task_timeout`` setting also applies to adopted tasks. version_added: 2.0.0 @@ -1779,7 +1779,7 @@ - name: stalled_task_timeout description: | Time in seconds after which tasks queued in celery are assumed to be stalled, and are automatically - cleared. Adopted tasks will instead use the ``task_adoption_timeout`` setting if specified. + rescheduled. 
Adopted tasks will instead use the ``task_adoption_timeout`` setting if specified. When set to 0, automatic clearing of stalled tasks is disabled. version_added: 2.3.1 type: integer diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 4838ea3857b135..07af45aecb683f 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -889,7 +889,7 @@ operation_timeout = 1.0 task_track_started = True # Time in seconds after which adopted tasks which are queued in celery are assumed to be stalled, -# and are automatically cleared. This setting does the same thing as ``stalled_task_timeout`` but +# and are automatically rescheduled. This setting does the same thing as ``stalled_task_timeout`` but # applies specifically to adopted tasks only. When set to 0, the ``stalled_task_timeout`` setting # also applies to adopted tasks. task_adoption_timeout = 600 From 2ea27edfb540e6fa3298158cd4ef875c01d47767 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Fri, 20 May 2022 12:19:37 +0100 Subject: [PATCH 7/8] Update airflow/executors/celery_executor.py --- airflow/executors/celery_executor.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py index def2d4336c8355..ed08da4672b5e4 100644 --- a/airflow/executors/celery_executor.py +++ b/airflow/executors/celery_executor.py @@ -536,6 +536,11 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance return not_adopted_tis def _set_celery_pending_task_timeout( + """ + We use the fact that dicts maintain insertion order, and the timeout for a + task is always "now + delta" to maintain the property that oldest item = first to + time out. + """ self, key: TaskInstanceKey, timeout_type: Optional[_CeleryPendingTaskTimeoutType] ) -> None: self.adopted_task_timeouts.pop(key, None) From 44fcc4ae821e01591652e6c0cd190a8b9d42d0de Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Fri, 20 May 2022 14:51:47 +0100 Subject: [PATCH 8/8] Update airflow/executors/celery_executor.py :facepalm: --- airflow/executors/celery_executor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py index ed08da4672b5e4..7b4c04e225a75a 100644 --- a/airflow/executors/celery_executor.py +++ b/airflow/executors/celery_executor.py @@ -536,13 +536,13 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance return not_adopted_tis def _set_celery_pending_task_timeout( + self, key: TaskInstanceKey, timeout_type: Optional[_CeleryPendingTaskTimeoutType] + ) -> None: """ We use the fact that dicts maintain insertion order, and the timeout for a task is always "now + delta" to maintain the property that oldest item = first to time out. """ - self, key: TaskInstanceKey, timeout_type: Optional[_CeleryPendingTaskTimeoutType] - ) -> None: self.adopted_task_timeouts.pop(key, None) self.stalled_task_timeouts.pop(key, None) if timeout_type == _CeleryPendingTaskTimeoutType.ADOPTED and self.task_adoption_timeout:
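
The series above leans on one property that is easy to miss in the diffs: both ``adopted_task_timeouts`` and ``stalled_task_timeouts`` are ordinary dicts that stay sorted by deadline, because Python dicts preserve insertion order and every deadline is computed as "now + a fixed delta". That is what lets ``_get_timedout_ti_keys`` stop scanning at the first deadline still in the future. Below is a minimal standalone sketch of that bookkeeping pattern only -- it is not code from the PR, and the ``PendingTaskTracker`` class and ``TaskKey`` alias are illustrative names, not Airflow APIs.

import datetime
from typing import Dict, List

TaskKey = str  # stand-in for airflow.models.taskinstance.TaskInstanceKey


def utcnow() -> datetime.datetime:
    return datetime.datetime.now(datetime.timezone.utc)


class PendingTaskTracker:
    """Tracks when pending tasks should be treated as stalled, using one fixed timeout."""

    def __init__(self, timeout: datetime.timedelta) -> None:
        self.timeout = timeout
        self.deadlines: Dict[TaskKey, datetime.datetime] = {}

    def track(self, key: TaskKey) -> None:
        # Pop-then-insert moves the key to the end; since every deadline is
        # "now + the same delta", the dict stays ordered by expiry.
        self.deadlines.pop(key, None)
        if self.timeout:  # a zero timeout disables tracking entirely
            self.deadlines[key] = utcnow() + self.timeout

    def untrack(self, key: TaskKey) -> None:
        # Called once the task is seen running or reaches a terminal state.
        self.deadlines.pop(key, None)

    def expired(self) -> List[TaskKey]:
        now = utcnow()
        expired_keys = []
        for key, deadline in self.deadlines.items():
            if deadline > now:
                # Entries are in expiry order, so every later entry is also in the future.
                break
            expired_keys.append(key)
        return expired_keys


if __name__ == "__main__":
    tracker = PendingTaskTracker(timeout=datetime.timedelta(seconds=600))
    tracker.track("ti-1")
    tracker.track("ti-2")
    print(tracker.expired())   # [] -- nothing has been pending for 600s yet

    disabled = PendingTaskTracker(timeout=datetime.timedelta(0))
    disabled.track("ti-3")     # mirrors stalled_task_timeout = 0: nothing is tracked
    print(disabled.expired())  # []

The executor keeps two such dicts rather than one because adopted and freshly queued tasks use different deltas (``task_adoption_timeout`` vs ``stalled_task_timeout``); keeping them separate preserves the "first inserted = first to expire" property and avoids the explicit re-sort the first commit needed.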