Automatically reschedule stalled queued tasks in CeleryExecutor (v2) #23690

Merged 8 commits on May 20, 2022
15 changes: 13 additions & 2 deletions airflow/config_templates/config.yml
@@ -1768,12 +1768,23 @@
default: "True"
- name: task_adoption_timeout
description: |
Time in seconds after which Adopted tasks are cleared by CeleryExecutor. This is helpful to clear
stalled tasks.
Time in seconds after which adopted tasks which are queued in celery are assumed to be stalled,
and are automatically rescheduled. This setting does the same thing as ``stalled_task_timeout`` but
applies specifically to adopted tasks only. When set to 0, the ``stalled_task_timeout`` setting
also applies to adopted tasks.
version_added: 2.0.0
type: integer
example: ~
default: "600"
- name: stalled_task_timeout
description: |
Time in seconds after which tasks queued in celery are assumed to be stalled, and are automatically
rescheduled. Adopted tasks will instead use the ``task_adoption_timeout`` setting if specified.
When set to 0, automatic clearing of stalled tasks is disabled.
version_added: 2.3.1
type: integer
example: ~
default: "0"
- name: task_publish_max_retries
description: |
The Maximum number of retries for publishing task messages to the broker when failing
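For context (this sketch is not part of the diff), the fallback described in the two descriptions above, where a ``task_adoption_timeout`` of 0 defers to ``stalled_task_timeout``, comes from how the executor builds its timedeltas; a minimal sketch mirroring the ``__init__`` change in ``celery_executor.py`` further down:

```python
# Minimal sketch mirroring the executor __init__ change in this PR: timedelta(0) is
# falsy, so a task_adoption_timeout of 0 makes adopted tasks share stalled_task_timeout.
import datetime

from airflow.configuration import conf

stalled_task_timeout = datetime.timedelta(
    seconds=conf.getint('celery', 'stalled_task_timeout', fallback=0)
)
task_adoption_timeout = (
    datetime.timedelta(seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600))
    or stalled_task_timeout
)
```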
11 changes: 9 additions & 2 deletions airflow/config_templates/default_airflow.cfg
@@ -888,10 +888,17 @@ operation_timeout = 1.0
# or run in HA mode, it can adopt the orphan tasks launched by previous SchedulerJob.
task_track_started = True

# Time in seconds after which Adopted tasks are cleared by CeleryExecutor. This is helpful to clear
# stalled tasks.
# Time in seconds after which adopted tasks which are queued in celery are assumed to be stalled,
# and are automatically rescheduled. This setting does the same thing as ``stalled_task_timeout`` but
# applies specifically to adopted tasks only. When set to 0, the ``stalled_task_timeout`` setting
# also applies to adopted tasks.
task_adoption_timeout = 600

# Time in seconds after which tasks queued in celery are assumed to be stalled, and are automatically
# rescheduled. Adopted tasks will instead use the ``task_adoption_timeout`` setting if specified.
# When set to 0, automatic clearing of stalled tasks is disabled.
stalled_task_timeout = 0

# The Maximum number of retries for publishing task messages to the broker when failing
# due to ``AirflowTaskTimeout`` error before giving up and marking Task as failed.
task_publish_max_retries = 3
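As a usage note rather than part of this change, these options can also be supplied through Airflow's standard ``AIRFLOW__{SECTION}__{KEY}`` environment variables, which take precedence over ``airflow.cfg``; the values below are purely illustrative:

```python
# Illustrative only: configure the new timeouts via environment variables before the
# scheduler process loads its config (e.g. in a container entrypoint).
import os

os.environ["AIRFLOW__CELERY__STALLED_TASK_TIMEOUT"] = "300"  # requeue stalled queued tasks after 5 minutes
os.environ["AIRFLOW__CELERY__TASK_ADOPTION_TIMEOUT"] = "0"   # adopted tasks then share stalled_task_timeout
```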
138 changes: 107 additions & 31 deletions airflow/executors/celery_executor.py
@@ -29,8 +29,9 @@
import subprocess
import time
import traceback
from collections import Counter, OrderedDict
from collections import Counter
from concurrent.futures import ProcessPoolExecutor
from enum import Enum
from multiprocessing import cpu_count
from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Set, Tuple, Union

@@ -40,6 +41,7 @@
from celery.result import AsyncResult
from celery.signals import import_modules as celery_import_modules
from setproctitle import setproctitle
from sqlalchemy.orm.session import Session

import airflow.settings as settings
from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG
@@ -50,6 +52,7 @@
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import State
from airflow.utils.timeout import timeout
from airflow.utils.timezone import utcnow
@@ -207,6 +210,11 @@ def on_celery_import_modules(*args, **kwargs):
pass


class _CeleryPendingTaskTimeoutType(Enum):
ADOPTED = 1
STALLED = 2


class CeleryExecutor(BaseExecutor):
"""
CeleryExecutor is recommended for production use of Airflow. It allows
@@ -230,10 +238,14 @@ def __init__(self):
self._sync_parallelism = max(1, cpu_count() - 1)
self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism)
self.tasks = {}
# Mapping of tasks we've adopted, ordered by the earliest date they timeout
self.adopted_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = OrderedDict()
self.task_adoption_timeout = datetime.timedelta(
seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600)
self.stalled_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = {}
self.stalled_task_timeout = datetime.timedelta(
seconds=conf.getint('celery', 'stalled_task_timeout', fallback=0)
)
self.adopted_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = {}
self.task_adoption_timeout = (
datetime.timedelta(seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600))
or self.stalled_task_timeout
)
self.task_publish_retries: Counter[TaskInstanceKey] = Counter()
self.task_publish_max_retries = conf.getint('celery', 'task_publish_max_retries', fallback=3)
@@ -285,6 +297,7 @@ def _process_tasks(self, task_tuples: List[TaskTuple]) -> None:
result.backend = cached_celery_backend
self.running.add(key)
self.tasks[key] = result
self._set_celery_pending_task_timeout(key, _CeleryPendingTaskTimeoutType.STALLED)

# Store the Celery task_id in the event buffer. This will get "overwritten" if the task
# has another event, but that is fine, because the only other events are success/failed at
@@ -315,25 +328,47 @@ def sync(self) -> None:
self.log.debug("No task to query celery, skipping sync")
return
self.update_all_task_states()
self._check_for_timedout_adopted_tasks()
self._check_for_stalled_tasks()

if self.adopted_task_timeouts:
self._check_for_stalled_adopted_tasks()
def _check_for_timedout_adopted_tasks(self) -> None:
timedout_keys = self._get_timedout_ti_keys(self.adopted_task_timeouts)
if timedout_keys:
self.log.error(
"Adopted tasks were still pending after %s, assuming they never made it to celery "
"and sending back to the scheduler:\n\t%s",
self.task_adoption_timeout,
"\n\t".join(repr(x) for x in timedout_keys),
)
self._send_stalled_tis_back_to_scheduler(timedout_keys)

def _check_for_stalled_adopted_tasks(self):
"""
See if any of the tasks we adopted from another Executor run have not
progressed after the configured timeout.
def _check_for_stalled_tasks(self) -> None:
timedout_keys = self._get_timedout_ti_keys(self.stalled_task_timeouts)
if timedout_keys:
self.log.error(
"Tasks were still pending after %s, assuming they never made it to celery "
"and sending back to the scheduler:\n\t%s",
self.stalled_task_timeout,
"\n\t".join(repr(x) for x in timedout_keys),
)
self._send_stalled_tis_back_to_scheduler(timedout_keys)

If they haven't, they likely never made it to Celery, and we should
just resend them. We do that by clearing the state and letting the
normal scheduler loop deal with that
def _get_timedout_ti_keys(
self, task_timeouts: Dict[TaskInstanceKey, datetime.datetime]
) -> List[TaskInstanceKey]:
"""
These timeouts exist to check whether any of our tasks have not progressed
in the expected time. This can happen for a few different reasons, usually related
to race conditions while shutting down schedulers and celery workers.

It is, of course, always possible that these tasks are not actually
stalled - they could just be waiting in a long celery queue.
Unfortunately there's no way for us to know for sure, so we'll just
reschedule them and let the normal scheduler loop requeue them.
"""
now = utcnow()

sorted_adopted_task_timeouts = sorted(self.adopted_task_timeouts.items(), key=lambda k: k[1])

timedout_keys = []
for key, stalled_after in sorted_adopted_task_timeouts:
for key, stalled_after in task_timeouts.items():
if stalled_after > now:
# Since all entries use the same timeout delta, insertion order is deadline
# order, so once we reach a stalled_after in the future we can stop
@@ -343,20 +378,46 @@ def _check_for_stalled_adopted_tasks(self):
# already finished, then it will be removed from this list -- so
# the only time it's still in this list is when it a) never made it
# to celery in the first place (i.e. race condition somewhere in
# the dying executor) or b) a really long celery queue and it just
# the dying executor), b) celery lost the task before execution
# started, or c) a really long celery queue and it just
# hasn't started yet -- better cancel it and let the scheduler
# re-queue rather than have this task risk stalling for ever
timedout_keys.append(key)
return timedout_keys

if timedout_keys:
self.log.error(
"Adopted tasks were still pending after %s, assuming they never made it to celery and "
"clearing:\n\t%s",
self.task_adoption_timeout,
"\n\t".join(repr(x) for x in timedout_keys),
@provide_session
def _send_stalled_tis_back_to_scheduler(
self, keys: List[TaskInstanceKey], session: Session = NEW_SESSION
) -> None:
try:
session.query(TaskInstance).filter(
TaskInstance.filter_for_tis(keys),
TaskInstance.state == State.QUEUED,
TaskInstance.queued_by_job_id == self.job_id,
).update(
{
TaskInstance.state: State.SCHEDULED,
TaskInstance.queued_dttm: None,
TaskInstance.queued_by_job_id: None,
TaskInstance.external_executor_id: None,
},
synchronize_session=False,
)
for key in timedout_keys:
self.change_state(key, State.FAILED)
session.commit()
except Exception:
self.log.exception("Error sending tasks back to scheduler")
session.rollback()
return

for key in keys:
self._set_celery_pending_task_timeout(key, None)
self.running.discard(key)
celery_async_result = self.tasks.pop(key, None)
if celery_async_result:
try:
app.control.revoke(celery_async_result.task_id)
except Exception as ex:
self.log.error("Error revoking task instance %s from celery: %s", key, ex)

def debug_dump(self) -> None:
"""Called in response to SIGUSR2 by the scheduler"""
@@ -369,6 +430,11 @@ def debug_dump(self) -> None:
len(self.adopted_task_timeouts),
"\n\t".join(map(repr, self.adopted_task_timeouts.items())),
)
self.log.info(
"executor.stalled_task_timeouts (%d)\n\t%s",
len(self.stalled_task_timeouts),
"\n\t".join(map(repr, self.stalled_task_timeouts.items())),
)

def update_all_task_states(self) -> None:
"""Updates states of the tasks."""
@@ -384,7 +450,7 @@ def update_all_task_states(self) -> None:
def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
super().change_state(key, state, info)
self.tasks.pop(key, None)
self.adopted_task_timeouts.pop(key, None)
self._set_celery_pending_task_timeout(key, None)

def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None:
"""Updates state of a single task."""
@@ -394,8 +460,8 @@ def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None
elif state in (celery_states.FAILURE, celery_states.REVOKED):
self.fail(key, info)
elif state == celery_states.STARTED:
# It's now actually running, so know it made it to celery okay!
self.adopted_task_timeouts.pop(key, None)
# It's now actually running, so we know it made it to celery okay!
self._set_celery_pending_task_timeout(key, None)
elif state == celery_states.PENDING:
pass
else:
@@ -455,7 +521,7 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance

# Set the correct elements of the state dicts, then update this
# like we just queried it.
self.adopted_task_timeouts[ti.key] = ti.queued_dttm + self.task_adoption_timeout
self._set_celery_pending_task_timeout(ti.key, _CeleryPendingTaskTimeoutType.ADOPTED)
self.tasks[ti.key] = result
self.running.add(ti.key)
self.update_task_state(ti.key, state, info)
@@ -469,6 +535,16 @@

return not_adopted_tis

def _set_celery_pending_task_timeout(
self, key: TaskInstanceKey, timeout_type: Optional[_CeleryPendingTaskTimeoutType]
) -> None:
self.adopted_task_timeouts.pop(key, None)
self.stalled_task_timeouts.pop(key, None)
if timeout_type == _CeleryPendingTaskTimeoutType.ADOPTED and self.task_adoption_timeout:
self.adopted_task_timeouts[key] = utcnow() + self.task_adoption_timeout
elif timeout_type == _CeleryPendingTaskTimeoutType.STALLED and self.stalled_task_timeout:
self.stalled_task_timeouts[key] = utcnow() + self.stalled_task_timeout


def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str, ExceptionWithTraceback], Any]:
"""
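To make the early-exit in ``_get_timedout_ti_keys`` concrete: every entry is recorded as ``utcnow()`` plus a fixed timeout, so a plain dict's insertion order is already deadline order and iteration can stop at the first deadline that is still in the future. Here is a small self-contained sketch; the names and values are illustrative, not the executor's actual code:

```python
# Standalone illustration of the early-exit used by _get_timedout_ti_keys (assumed,
# simplified names): entries are inserted as "now + fixed delta", so insertion order
# equals deadline order and we can stop at the first deadline that is still ahead of us.
import datetime


def get_timedout_keys(task_timeouts: dict, now: datetime.datetime) -> list:
    timedout = []
    for key, stalled_after in task_timeouts.items():
        if stalled_after > now:
            break  # later insertions have later deadlines, nothing further can be timed out
        timedout.append(key)
    return timedout


now = datetime.datetime(2022, 5, 20, 12, 0, 0)
timeouts = {
    "ti_a": now - datetime.timedelta(minutes=5),   # deadline already passed
    "ti_b": now - datetime.timedelta(seconds=1),   # deadline just passed
    "ti_c": now + datetime.timedelta(minutes=10),  # still within its window
}
print(get_timedout_keys(timeouts, now))  # ['ti_a', 'ti_b']
```

This ordering property is also why the PR can drop the ``OrderedDict`` and the explicit ``sorted(...)`` call used by the previous implementation.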