diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index e8e3b8463bb..4f0dc839b49 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -6,7 +6,6 @@ from __future__ import annotations -import itertools from abc import ABC, abstractmethod from collections import defaultdict from dataclasses import dataclass diff --git a/ax/service/tests/test_scheduler.py b/ax/service/tests/test_scheduler.py index c3ba8de2262..3b522b6bb88 100644 --- a/ax/service/tests/test_scheduler.py +++ b/ax/service/tests/test_scheduler.py @@ -10,7 +10,7 @@ from math import ceil from random import randint from tempfile import NamedTemporaryFile -from typing import Any, Dict, Iterable, Optional, Set, Tuple +from typing import Any, Dict, Iterable, Set, Tuple from unittest.mock import patch from ax.core.base_trial import BaseTrial, TrialStatus @@ -201,6 +201,15 @@ def test_validate_runners_if_required(self): scheduler.run_all_trials() def test_validate_early_stopping_strategy(self): + class DummyEarlyStoppingStrategy(BaseEarlyStoppingStrategy): + def should_stop_trials_early( + self, + trial_indices: Set[int], + experiment: Experiment, + **kwargs: Dict[str, Any], + ) -> Set[int]: + return {} + with patch( f"{BraninMetric.__module__}.BraninMetric.is_available_while_running", return_value=False, @@ -209,7 +218,7 @@ def test_validate_early_stopping_strategy(self): experiment=self.branin_experiment, generation_strategy=self.sobol_GPEI_GS, options=SchedulerOptions( - early_stopping_strategy=BaseEarlyStoppingStrategy() + early_stopping_strategy=DummyEarlyStoppingStrategy() ), ) @@ -218,7 +227,7 @@ def test_validate_early_stopping_strategy(self): experiment=self.branin_experiment, generation_strategy=self.sobol_GPEI_GS, options=SchedulerOptions( - early_stopping_strategy=BaseEarlyStoppingStrategy() + early_stopping_strategy=DummyEarlyStoppingStrategy() ), )