diff --git a/dask_kubernetes/operator/controller/controller.py b/dask_kubernetes/operator/controller/controller.py
index c34aa0ce..ef0c3d61 100644
--- a/dask_kubernetes/operator/controller/controller.py
+++ b/dask_kubernetes/operator/controller/controller.py
@@ -10,6 +10,7 @@ from uuid import uuid4
 
 import aiohttp
+import anyio
 import dask.config
 import kopf
 import kr8s
 
@@ -646,16 +647,19 @@ async def daskworkergroup_replica_update(
     # Replica updates can come in quick succession and the changes must be applied atomically to ensure
     # the number of workers ends in the correct state
     async with worker_group_scale_locks[f"{namespace}/{name}"]:
-        current_workers = len(
-            await kr8s.asyncio.get(
-                "deployments",
-                namespace=namespace,
-                label_selector={"dask.org/workergroup-name": name},
-            )
+        current_workers = await kr8s.asyncio.get(
+            "deployments",
+            namespace=namespace,
+            label_selector={"dask.org/workergroup-name": name},
+        )
+        # Sort workers so that long-lived workers come first in the list
+        current_workers = sorted(
+            current_workers,
+            key=lambda d: datetime.fromisoformat(d.metadata["creationTimestamp"]),
         )
         assert isinstance(new, int)
         desired_workers = new
-        workers_needed = desired_workers - current_workers
+        workers_needed = desired_workers - len(current_workers)
         labels = _get_labels(meta)
         annotations = _get_annotations(meta)
         worker_spec = spec["worker"]
@@ -695,22 +699,48 @@ async def daskworkergroup_replica_update(
             )
             logger.info(f"Scaled worker group {name} up to {desired_workers} workers.")
         if workers_needed < 0:
-            worker_ids = await retire_workers(
-                n_workers=-workers_needed,
-                scheduler_service_name=SCHEDULER_NAME_TEMPLATE.format(
-                    cluster_name=cluster_name
-                ),
-                worker_group_name=name,
-                namespace=namespace,
-                logger=logger,
-            )
-            logger.info(f"Workers to close: {worker_ids}")
-            for wid in worker_ids:
-                worker_deployment = await Deployment(wid, namespace=namespace)
-                await worker_deployment.delete()
-            logger.info(
-                f"Scaled worker group {name} down to {desired_workers} workers."
-            )
+            # We prioritize the deletion of newly created and unready deployments
+            recent_workers = current_workers[::-1]
+
+            unready_deployments = []
+            for idx in range(-workers_needed):
+                if idx >= len(recent_workers):
+                    break
+                deployment = recent_workers[idx]
+                if not (
+                    deployment.raw["status"].get("observedGeneration", 0)
+                    >= deployment.raw["metadata"]["generation"]
+                    and deployment.raw["status"].get("readyReplicas", 0)
+                    == deployment.replicas
+                ):
+                    unready_deployments.append(deployment)
+
+            async with anyio.create_task_group() as tg:
+                for deployment in unready_deployments:
+                    tg.start_soon(deployment.delete)
+
+            if unready_deployments:
+                logger.info(f"Deleted {len(unready_deployments)} unready workers.")
+
+            n_workers = -workers_needed - len(unready_deployments)
+
+            if n_workers > 0:
+                worker_ids = await retire_workers(
+                    n_workers=n_workers,
+                    scheduler_service_name=SCHEDULER_NAME_TEMPLATE.format(
+                        cluster_name=cluster_name
+                    ),
+                    worker_group_name=name,
+                    namespace=namespace,
+                    logger=logger,
+                )
+                logger.info(f"Workers to close: {worker_ids}")
+                for wid in worker_ids:
+                    worker_deployment = await Deployment(wid, namespace=namespace)
+                    await worker_deployment.delete()
+                logger.info(
+                    f"Scaled worker group {name} down to {desired_workers} workers."
+                )
 
 
 @kopf.on.delete("daskworkergroup.kubernetes.dask.org", optional=True)
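
For reference, the readiness test used when picking scale-down candidates can be read as a standalone predicate over a Deployment manifest: a deployment counts as ready only when the controller has observed its latest generation and every requested replica reports ready. Below is a minimal sketch under those assumptions; is_deployment_ready is a hypothetical helper written against the standard Kubernetes Deployment fields (metadata.generation, status.observedGeneration, status.readyReplicas, spec.replicas), not part of the patch or of kr8s, and deployment.replicas in the patch is assumed to correspond to spec.replicas here.

def is_deployment_ready(manifest: dict) -> bool:
    # Ready == the rollout of the latest generation has been observed by the
    # controller and all requested replicas report ready.
    status = manifest.get("status", {})
    return (
        status.get("observedGeneration", 0) >= manifest["metadata"]["generation"]
        and status.get("readyReplicas", 0) == manifest["spec"]["replicas"]
    )


# A deployment whose rollout has not completed is "unready" and therefore a
# preferred candidate for deletion when scaling down.
pending = {
    "metadata": {"generation": 2},
    "spec": {"replicas": 1},
    "status": {"observedGeneration": 1, "readyReplicas": 0},
}
assert not is_deployment_ready(pending)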