From d534bdf8fa47fd0bdf410d85326c92c9c73ef3a2 Mon Sep 17 00:00:00 2001 From: Chris Redekop <32752154+repl-chris@users.noreply.github.com> Date: Fri, 20 May 2022 08:40:58 -0600 Subject: [PATCH] Automatically reschedule stalled queued tasks in CeleryExecutor (v2) (#23690) Celery can lose tasks on worker shutdown, causing airflow to just wait on them indefinitely (may be related to celery/celery#7266). This PR expands the "stalled tasks" functionality which is already in place for adopted tasks, and adds the ability to apply it to all tasks such that these lost/hung tasks can be automatically recovered and queued up again. (cherry picked from commit baae70c88ed45d4b45e64754cb3decb99472c601) --- airflow/config_templates/config.yml | 15 +- airflow/config_templates/default_airflow.cfg | 11 +- airflow/executors/celery_executor.py | 143 ++++++++++++++---- tests/executors/test_celery_executor.py | 151 ++++++++++++++----- 4 files changed, 245 insertions(+), 75 deletions(-) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index bc8f42035b9c8..3d89a3004177e 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -1767,12 +1767,23 @@ default: "True" - name: task_adoption_timeout description: | - Time in seconds after which Adopted tasks are cleared by CeleryExecutor. This is helpful to clear - stalled tasks. + Time in seconds after which adopted tasks which are queued in celery are assumed to be stalled, + and are automatically rescheduled. This setting does the same thing as ``stalled_task_timeout`` but + applies specifically to adopted tasks only. When set to 0, the ``stalled_task_timeout`` setting + also applies to adopted tasks. version_added: 2.0.0 type: integer example: ~ default: "600" + - name: stalled_task_timeout + description: | + Time in seconds after which tasks queued in celery are assumed to be stalled, and are automatically + rescheduled. Adopted tasks will instead use the ``task_adoption_timeout`` setting if specified. + When set to 0, automatic clearing of stalled tasks is disabled. + version_added: 2.3.1 + type: integer + example: ~ + default: "0" - name: task_publish_max_retries description: | The Maximum number of retries for publishing task messages to the broker when failing diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index a6583d8f47293..e177751e899e3 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -887,10 +887,17 @@ operation_timeout = 1.0 # or run in HA mode, it can adopt the orphan tasks launched by previous SchedulerJob. task_track_started = True -# Time in seconds after which Adopted tasks are cleared by CeleryExecutor. This is helpful to clear -# stalled tasks. +# Time in seconds after which adopted tasks which are queued in celery are assumed to be stalled, +# and are automatically rescheduled. This setting does the same thing as ``stalled_task_timeout`` but +# applies specifically to adopted tasks only. When set to 0, the ``stalled_task_timeout`` setting +# also applies to adopted tasks. task_adoption_timeout = 600 +# Time in seconds after which tasks queued in celery are assumed to be stalled, and are automatically +# rescheduled. Adopted tasks will instead use the ``task_adoption_timeout`` setting if specified. +# When set to 0, automatic clearing of stalled tasks is disabled. 
+stalled_task_timeout = 0 + # The Maximum number of retries for publishing task messages to the broker when failing # due to ``AirflowTaskTimeout`` error before giving up and marking Task as failed. task_publish_max_retries = 3 diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py index 752244d5528c6..7b4c04e225a75 100644 --- a/airflow/executors/celery_executor.py +++ b/airflow/executors/celery_executor.py @@ -29,8 +29,9 @@ import subprocess import time import traceback -from collections import Counter, OrderedDict +from collections import Counter from concurrent.futures import ProcessPoolExecutor +from enum import Enum from multiprocessing import cpu_count from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Set, Tuple, Union @@ -40,6 +41,7 @@ from celery.result import AsyncResult from celery.signals import import_modules as celery_import_modules from setproctitle import setproctitle +from sqlalchemy.orm.session import Session import airflow.settings as settings from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG @@ -50,6 +52,7 @@ from airflow.stats import Stats from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import State from airflow.utils.timeout import timeout from airflow.utils.timezone import utcnow @@ -207,6 +210,11 @@ def on_celery_import_modules(*args, **kwargs): pass +class _CeleryPendingTaskTimeoutType(Enum): + ADOPTED = 1 + STALLED = 2 + + class CeleryExecutor(BaseExecutor): """ CeleryExecutor is recommended for production use of Airflow. It allows @@ -230,10 +238,14 @@ def __init__(self): self._sync_parallelism = max(1, cpu_count() - 1) self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism) self.tasks = {} - # Mapping of tasks we've adopted, ordered by the earliest date they timeout - self.adopted_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = OrderedDict() - self.task_adoption_timeout = datetime.timedelta( - seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600) + self.stalled_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = {} + self.stalled_task_timeout = datetime.timedelta( + seconds=conf.getint('celery', 'stalled_task_timeout', fallback=0) + ) + self.adopted_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = {} + self.task_adoption_timeout = ( + datetime.timedelta(seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600)) + or self.stalled_task_timeout ) self.task_publish_retries: Counter[TaskInstanceKey] = Counter() self.task_publish_max_retries = conf.getint('celery', 'task_publish_max_retries', fallback=3) @@ -285,6 +297,7 @@ def _process_tasks(self, task_tuples: List[TaskTuple]) -> None: result.backend = cached_celery_backend self.running.add(key) self.tasks[key] = result + self._set_celery_pending_task_timeout(key, _CeleryPendingTaskTimeoutType.STALLED) # Store the Celery task_id in the event buffer. 
This will get "overwritten" if the task # has another event, but that is fine, because the only other events are success/failed at @@ -315,25 +328,47 @@ def sync(self) -> None: self.log.debug("No task to query celery, skipping sync") return self.update_all_task_states() + self._check_for_timedout_adopted_tasks() + self._check_for_stalled_tasks() + + def _check_for_timedout_adopted_tasks(self) -> None: + timedout_keys = self._get_timedout_ti_keys(self.adopted_task_timeouts) + if timedout_keys: + self.log.error( + "Adopted tasks were still pending after %s, assuming they never made it to celery " + "and sending back to the scheduler:\n\t%s", + self.task_adoption_timeout, + "\n\t".join(repr(x) for x in timedout_keys), + ) + self._send_stalled_tis_back_to_scheduler(timedout_keys) - if self.adopted_task_timeouts: - self._check_for_stalled_adopted_tasks() + def _check_for_stalled_tasks(self) -> None: + timedout_keys = self._get_timedout_ti_keys(self.stalled_task_timeouts) + if timedout_keys: + self.log.error( + "Tasks were still pending after %s, assuming they never made it to celery " + "and sending back to the scheduler:\n\t%s", + self.stalled_task_timeout, + "\n\t".join(repr(x) for x in timedout_keys), + ) + self._send_stalled_tis_back_to_scheduler(timedout_keys) - def _check_for_stalled_adopted_tasks(self): + def _get_timedout_ti_keys( + self, task_timeouts: Dict[TaskInstanceKey, datetime.datetime] + ) -> List[TaskInstanceKey]: """ - See if any of the tasks we adopted from another Executor run have not - progressed after the configured timeout. - - If they haven't, they likely never made it to Celery, and we should - just resend them. We do that by clearing the state and letting the - normal scheduler loop deal with that + These timeouts exist to check to see if any of our tasks have not progressed + in the expected time. This can happen for a few different reasons, usually related + to race conditions while shutting down schedulers and celery workers. + + It is, of course, always possible that these tasks are not actually + stalled - they could just be waiting in a long celery queue. + Unfortunately there's no way for us to know for sure, so we'll just + reschedule them and let the normal scheduler loop requeue them. """ now = utcnow() - - sorted_adopted_task_timeouts = sorted(self.adopted_task_timeouts.items(), key=lambda k: k[1]) - timedout_keys = [] - for key, stalled_after in sorted_adopted_task_timeouts: + for key, stalled_after in task_timeouts.items(): if stalled_after > now: # Since items are stored sorted, if we get to a stalled_after # in the future then we can stop @@ -343,20 +378,46 @@ def _check_for_stalled_adopted_tasks(self): # already finished, then it will be removed from this list -- so # the only time it's still in this list is when it a) never made it # to celery in the first place (i.e.
race condition somewhere in - # the dying executor) or b) a really long celery queue and it just + # the dying executor), b) celery lost the task before execution + # started, or c) a really long celery queue and it just # hasn't started yet -- better cancel it and let the scheduler # re-queue rather than have this task risk stalling for ever timedout_keys.append(key) + return timedout_keys - if timedout_keys: - self.log.error( - "Adopted tasks were still pending after %s, assuming they never made it to celery and " - "clearing:\n\t%s", - self.task_adoption_timeout, - "\n\t".join(repr(x) for x in timedout_keys), + @provide_session + def _send_stalled_tis_back_to_scheduler( + self, keys: List[TaskInstanceKey], session: Session = NEW_SESSION + ) -> None: + try: + session.query(TaskInstance).filter( + TaskInstance.filter_for_tis(keys), + TaskInstance.state == State.QUEUED, + TaskInstance.queued_by_job_id == self.job_id, + ).update( + { + TaskInstance.state: State.SCHEDULED, + TaskInstance.queued_dttm: None, + TaskInstance.queued_by_job_id: None, + TaskInstance.external_executor_id: None, + }, + synchronize_session=False, ) - for key in timedout_keys: - self.change_state(key, State.FAILED) + session.commit() + except Exception: + self.log.exception("Error sending tasks back to scheduler") + session.rollback() + return + + for key in keys: + self._set_celery_pending_task_timeout(key, None) + self.running.discard(key) + celery_async_result = self.tasks.pop(key, None) + if celery_async_result: + try: + app.control.revoke(celery_async_result.task_id) + except Exception as ex: + self.log.error("Error revoking task instance %s from celery: %s", key, ex) def debug_dump(self) -> None: """Called in response to SIGUSR2 by the scheduler""" @@ -369,6 +430,11 @@ def debug_dump(self) -> None: len(self.adopted_task_timeouts), "\n\t".join(map(repr, self.adopted_task_timeouts.items())), ) + self.log.info( + "executor.stalled_task_timeouts (%d)\n\t%s", + len(self.stalled_task_timeouts), + "\n\t".join(map(repr, self.stalled_task_timeouts.items())), + ) def update_all_task_states(self) -> None: """Updates states of the tasks.""" @@ -384,7 +450,7 @@ def update_all_task_states(self) -> None: def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None: super().change_state(key, state, info) self.tasks.pop(key, None) - self.adopted_task_timeouts.pop(key, None) + self._set_celery_pending_task_timeout(key, None) def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None: """Updates state of a single task.""" @@ -394,8 +460,8 @@ def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None elif state in (celery_states.FAILURE, celery_states.REVOKED): self.fail(key, info) elif state == celery_states.STARTED: - # It's now actually running, so know it made it to celery okay! - self.adopted_task_timeouts.pop(key, None) + # It's now actually running, so we know it made it to celery okay! + self._set_celery_pending_task_timeout(key, None) elif state == celery_states.PENDING: pass else: @@ -455,7 +521,7 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance # Set the correct elements of the state dicts, then update this # like we just queried it. 
- self.adopted_task_timeouts[ti.key] = ti.queued_dttm + self.task_adoption_timeout + self._set_celery_pending_task_timeout(ti.key, _CeleryPendingTaskTimeoutType.ADOPTED) self.tasks[ti.key] = result self.running.add(ti.key) self.update_task_state(ti.key, state, info) @@ -469,6 +535,21 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance return not_adopted_tis + def _set_celery_pending_task_timeout( + self, key: TaskInstanceKey, timeout_type: Optional[_CeleryPendingTaskTimeoutType] + ) -> None: + """ + We use the fact that dicts maintain insertion order, and the timeout for a + task is always "now + delta" to maintain the property that oldest item = first to + time out. + """ + self.adopted_task_timeouts.pop(key, None) + self.stalled_task_timeouts.pop(key, None) + if timeout_type == _CeleryPendingTaskTimeoutType.ADOPTED and self.task_adoption_timeout: + self.adopted_task_timeouts[key] = utcnow() + self.task_adoption_timeout + elif timeout_type == _CeleryPendingTaskTimeoutType.STALLED and self.stalled_task_timeout: + self.stalled_task_timeouts[key] = utcnow() + self.stalled_task_timeout + def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str, ExceptionWithTraceback], Any]: """ diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py index d5bd3dbeae7a3..9d9485e3ce3da 100644 --- a/tests/executors/test_celery_executor.py +++ b/tests/executors/test_celery_executor.py @@ -33,6 +33,7 @@ from celery.backends.database import DatabaseBackend from celery.contrib.testing.worker import start_worker from celery.result import AsyncResult +from freezegun import freeze_time from kombu.asynchronous import set_event_loop from parameterized import parameterized @@ -311,9 +312,9 @@ def test_try_adopt_task_instances_none(self): assert executor.try_adopt_task_instances(tis) == tis @pytest.mark.backend("mysql", "postgres") + @freeze_time("2020-01-01") def test_try_adopt_task_instances(self): start_date = timezone.utcnow() - timedelta(days=2) - queued_dttm = timezone.utcnow() - timedelta(minutes=1) try_number = 1 @@ -323,17 +324,16 @@ def test_try_adopt_task_instances(self): ti1 = TaskInstance(task=task_1, run_id=None) ti1.external_executor_id = '231' - ti1.queued_dttm = queued_dttm ti1.state = State.QUEUED ti2 = TaskInstance(task=task_2, run_id=None) ti2.external_executor_id = '232' - ti2.queued_dttm = queued_dttm ti2.state = State.QUEUED tis = [ti1, ti2] executor = celery_executor.CeleryExecutor() assert executor.running == set() assert executor.adopted_task_timeouts == {} + assert executor.stalled_task_timeouts == {} assert executor.tasks == {} not_adopted_tis = executor.try_adopt_task_instances(tis) @@ -341,67 +341,138 @@ def test_try_adopt_task_instances(self): key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, None, try_number) key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, None, try_number) assert executor.running == {key_1, key_2} - assert dict(executor.adopted_task_timeouts) == { - key_1: queued_dttm + executor.task_adoption_timeout, - key_2: queued_dttm + executor.task_adoption_timeout, + assert executor.adopted_task_timeouts == { + key_1: timezone.utcnow() + executor.task_adoption_timeout, + key_2: timezone.utcnow() + executor.task_adoption_timeout, } + assert executor.stalled_task_timeouts == {} assert executor.tasks == {key_1: AsyncResult("231"), key_2: AsyncResult("232")} assert not_adopted_tis == [] - @pytest.mark.backend("mysql", "postgres") - def test_check_for_stalled_adopted_tasks(self):
start_date = timezone.utcnow() - timedelta(days=2) - queued_dttm = timezone.utcnow() - timedelta(minutes=30) - - try_number = 1 + @pytest.fixture + def mock_celery_revoke(self): + with _prepare_app() as app: + app.control.revoke = mock.MagicMock() + yield app.control.revoke - with DAG("test_check_for_stalled_adopted_tasks") as dag: - task_1 = BaseOperator(task_id="task_1", start_date=start_date) - task_2 = BaseOperator(task_id="task_2", start_date=start_date) + @pytest.mark.backend("mysql", "postgres") + def test_check_for_timedout_adopted_tasks(self, create_dummy_dag, dag_maker, session, mock_celery_revoke): + create_dummy_dag(dag_id="test_clear_stalled", task_id="task1", with_dagrun_type=None) + dag_run = dag_maker.create_dagrun() - key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, "runid", try_number) - key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, "runid", try_number) + ti = dag_run.task_instances[0] + ti.state = State.QUEUED + ti.queued_dttm = timezone.utcnow() + ti.queued_by_job_id = 1 + ti.external_executor_id = '231' + session.flush() executor = celery_executor.CeleryExecutor() + executor.job_id = 1 executor.adopted_task_timeouts = { - key_1: queued_dttm + executor.task_adoption_timeout, - key_2: queued_dttm + executor.task_adoption_timeout, + ti.key: timezone.utcnow() - timedelta(days=1), } - executor.running = {key_1, key_2} - executor.tasks = {key_1: AsyncResult("231"), key_2: AsyncResult("232")} + executor.running = {ti.key} + executor.tasks = {ti.key: AsyncResult("231")} executor.sync() - assert executor.event_buffer == {key_1: (State.FAILED, None), key_2: (State.FAILED, None)} + assert executor.event_buffer == {} assert executor.tasks == {} assert executor.running == set() assert executor.adopted_task_timeouts == {} + assert mock_celery_revoke.called_with("231") + + ti.refresh_from_db() + assert ti.state == State.SCHEDULED + assert ti.queued_by_job_id is None + assert ti.queued_dttm is None + assert ti.external_executor_id is None @pytest.mark.backend("mysql", "postgres") - def test_check_for_stalled_adopted_tasks_goes_in_ordered_fashion(self): - start_date = timezone.utcnow() - timedelta(days=2) - queued_dttm = timezone.utcnow() - timedelta(minutes=30) - queued_dttm_2 = timezone.utcnow() - timedelta(minutes=4) + def test_check_for_stalled_tasks(self, create_dummy_dag, dag_maker, session, mock_celery_revoke): + create_dummy_dag(dag_id="test_clear_stalled", task_id="task1", with_dagrun_type=None) + dag_run = dag_maker.create_dagrun() - try_number = 1 + ti = dag_run.task_instances[0] + ti.state = State.QUEUED + ti.queued_dttm = timezone.utcnow() + ti.queued_by_job_id = 1 + ti.external_executor_id = '231' + session.flush() - with DAG("test_check_for_stalled_adopted_tasks") as dag: + executor = celery_executor.CeleryExecutor() + executor.job_id = 1 + executor.stalled_task_timeouts = { + ti.key: timezone.utcnow() - timedelta(days=1), + } + executor.running = {ti.key} + executor.tasks = {ti.key: AsyncResult("231")} + executor.sync() + assert executor.event_buffer == {} + assert executor.tasks == {} + assert executor.running == set() + assert executor.stalled_task_timeouts == {} + assert mock_celery_revoke.called_with("231") + + ti.refresh_from_db() + assert ti.state == State.SCHEDULED + assert ti.queued_by_job_id is None + assert ti.queued_dttm is None + assert ti.external_executor_id is None + + @pytest.mark.backend("mysql", "postgres") + @freeze_time("2020-01-01") + def test_pending_tasks_timeout_with_appropriate_config_setting(self): + start_date = timezone.utcnow() - 
timedelta(days=2) + + with DAG("test_check_for_stalled_tasks_are_ordered"): task_1 = BaseOperator(task_id="task_1", start_date=start_date) task_2 = BaseOperator(task_id="task_2", start_date=start_date) - key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, "runid", try_number) - key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, "runid", try_number) + ti1 = TaskInstance(task=task_1, run_id=None) + ti1.external_executor_id = '231' + ti1.state = State.QUEUED + ti2 = TaskInstance(task=task_2, run_id=None) + ti2.external_executor_id = '232' + ti2.state = State.QUEUED executor = celery_executor.CeleryExecutor() - executor.adopted_task_timeouts = { - key_2: queued_dttm_2 + executor.task_adoption_timeout, - key_1: queued_dttm + executor.task_adoption_timeout, + executor.stalled_task_timeout = timedelta(seconds=30) + executor.queued_tasks[ti2.key] = (None, None, None, None) + executor.try_adopt_task_instances([ti1]) + with mock.patch('airflow.executors.celery_executor.send_task_to_executor') as mock_send_task: + mock_send_task.return_value = (ti2.key, None, mock.MagicMock()) + executor._process_tasks([(ti2.key, None, None, mock.MagicMock())]) + assert executor.stalled_task_timeouts == { + ti2.key: timezone.utcnow() + timedelta(seconds=30), } - executor.running = {key_1, key_2} - executor.tasks = {key_1: AsyncResult("231"), key_2: AsyncResult("232")} - executor.sync() - assert executor.event_buffer == {key_1: (State.FAILED, None)} - assert executor.tasks == {key_2: AsyncResult('232')} - assert executor.running == {key_2} - assert executor.adopted_task_timeouts == {key_2: queued_dttm_2 + executor.task_adoption_timeout} + assert executor.adopted_task_timeouts == { + ti1.key: timezone.utcnow() + timedelta(seconds=600), + } + + @pytest.mark.backend("mysql", "postgres") + def test_no_pending_task_timeouts_when_configured(self): + start_date = timezone.utcnow() - timedelta(days=2) + + with DAG("test_check_for_stalled_tasks_are_ordered"): + task_1 = BaseOperator(task_id="task_1", start_date=start_date) + task_2 = BaseOperator(task_id="task_2", start_date=start_date) + + ti1 = TaskInstance(task=task_1, run_id=None) + ti1.external_executor_id = '231' + ti1.state = State.QUEUED + ti2 = TaskInstance(task=task_2, run_id=None) + ti2.external_executor_id = '232' + ti2.state = State.QUEUED + + executor = celery_executor.CeleryExecutor() + executor.task_adoption_timeout = timedelta(0) + executor.queued_tasks[ti2.key] = (None, None, None, None) + executor.try_adopt_task_instances([ti1]) + with mock.patch('airflow.executors.celery_executor.send_task_to_executor') as mock_send_task: + mock_send_task.return_value = (ti2.key, None, mock.MagicMock()) + executor._process_tasks([(ti2.key, None, None, mock.MagicMock())]) + assert executor.adopted_task_timeouts == {} + assert executor.stalled_task_timeouts == {} def test_operation_timeout_config():
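A note for reviewers, since the mechanism is spread across several hunks above: the executor now keeps two insertion-ordered dicts keyed by TaskInstanceKey (adopted_task_timeouts and stalled_task_timeouts) whose values are deadlines of the form "now + delta". Because each dict uses a constant delta, insertion order equals deadline order, so sync() can stop scanning at the first unexpired entry; expired keys are reset to SCHEDULED in the metadata database, revoked from celery, and picked up again by the normal scheduler loop. The behaviour is opt-in: for example, setting stalled_task_timeout = 300 under [celery] in airflow.cfg would reschedule tasks still queued in celery after five minutes (300 is only an illustrative value), while the default of 0 leaves non-adopted tasks untouched and adopted tasks governed by task_adoption_timeout as before. The sketch below is a rough, framework-free illustration of that bookkeeping only; the names PendingTaskTracker, set_timeout, and timed_out are invented for this example and do not exist in Airflow:

import datetime
from typing import Dict, List, Optional

TaskKey = str  # stand-in for Airflow's TaskInstanceKey


def _now() -> datetime.datetime:
    return datetime.datetime.now(datetime.timezone.utc)


class PendingTaskTracker:
    """Tracks how long queued tasks have been pending, oldest deadline first."""

    def __init__(self, adoption_timeout: datetime.timedelta, stalled_timeout: datetime.timedelta):
        self.adopted: Dict[TaskKey, datetime.datetime] = {}
        self.stalled: Dict[TaskKey, datetime.datetime] = {}
        # Mirrors the patch: a zero adoption timeout falls back to the stalled timeout.
        self.adoption_timeout = adoption_timeout or stalled_timeout
        self.stalled_timeout = stalled_timeout

    def set_timeout(self, key: TaskKey, kind: Optional[str]) -> None:
        # A key lives in at most one dict; kind=None clears any pending deadline
        # (the task started, finished, or was already sent back to the scheduler).
        self.adopted.pop(key, None)
        self.stalled.pop(key, None)
        if kind == "adopted" and self.adoption_timeout:
            self.adopted[key] = _now() + self.adoption_timeout
        elif kind == "stalled" and self.stalled_timeout:
            self.stalled[key] = _now() + self.stalled_timeout

    def timed_out(self, deadlines: Dict[TaskKey, datetime.datetime]) -> List[TaskKey]:
        # Deadlines are always "insertion time + constant delta", so the dict is
        # already ordered by deadline and we can stop at the first future one.
        now = _now()
        expired: List[TaskKey] = []
        for key, deadline in deadlines.items():
            if deadline > now:
                break
            expired.append(key)
        return expired


tracker = PendingTaskTracker(datetime.timedelta(seconds=600), datetime.timedelta(seconds=300))
tracker.set_timeout("dag_a.task_1.run_1", "stalled")  # queued directly by this executor
tracker.set_timeout("dag_b.task_2.run_7", "adopted")  # adopted from a dead scheduler
print(tracker.timed_out(tracker.stalled))  # [] until 300 seconds have elapsed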