From 06e93dc5ab97733cdcc6299ebfbfe03ef2fadeef Mon Sep 17 00:00:00 2001
From: Lena Kashtelyan
Date: Wed, 23 Jun 2021 11:35:22 -0700
Subject: [PATCH] Add sweep messages to `ParamSweepScheduler` and set up for
 searcher explanation in outputs

Summary:
Adds a `markdown_messages` mapping to the scheduler, seeded at instantiation
with a message naming the generation strategy in use, and records an
optimization-completion message when a sweep finishes or aborts. Also makes
`run_trials_and_yield_results` yield its final result instead of returning it,
which removes the `pyre-fixme[7]` suppressions for T84274305; tests are updated
to expect the extra yielded result.

Reviewed By: bernardbeckerman

Differential Revision: D29245864

fbshipit-source-id: 1d1614535ded544892d72c05b8ef69f81e2fddac
---
 ax/service/scheduler.py            | 46 ++++++++++++++++++++----------
 ax/service/tests/test_scheduler.py |  6 ++--
 2 files changed, 39 insertions(+), 13 deletions(-)

diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py
index 60e72d312de..4f70fb3f11f 100644
--- a/ax/service/scheduler.py
+++ b/ax/service/scheduler.py
@@ -57,6 +57,11 @@
 If this functionality is desired, specify the method in the scheduler subclass.
 """
 
+GS_TYPE_MSG = "This optimization run uses a '{gs_name}' generation strategy."
+OPTIMIZATION_COMPLETION_MSG = """Optimization completed with a total of {num_trials}
+trials attached to the underlying Ax experiment '{experiment_name}'.
+"""
+
 # Wait time b/w polls will not exceed 5 mins.
 MAX_SECONDS_BETWEEN_POLLS = 300
 
@@ -215,6 +220,11 @@ class Scheduler(WithDBSettingsBase, ABC):
     generation_strategy: GenerationStrategy
     options: SchedulerOptions
     logger: LoggerAdapter
+    # Mapping of form {short string identifier -> message to show in reported
+    # results}. This is a mapping rather than a list so that individual sweep
+    # messages can be updated over the course of the optimization (e.g. a
+    # progress report for the optimization).
+    markdown_messages: Dict[str, str]
 
     # Number of trials that existed on the scheduler's experiment before
     # the scheduler instantiation with that experiment.
@@ -280,6 +290,9 @@ def __init__(
         # when trials are not generated for the same reason multiple times in
         # a row.
         self._log_next_no_trials_reason = True
+        self.markdown_messages = {
+            "generation_strategy": GS_TYPE_MSG.format(gs_name=generation_strategy.name)
+        }
 
     @classmethod
     def get_default_db_settings(cls) -> DBSettings:
@@ -686,7 +699,7 @@ def summarize_final_result(self) -> OptimizationResult:
 
     def run_trials_and_yield_results(
         self, max_trials: int, timeout_hours: Optional[int] = None
-    ) -> Generator[Dict[str, Any], None, Dict[str, Any]]:
+    ) -> Generator[Dict[str, Any], None, None]:
         """Make continuous calls to `run` and `process_results` to run up to
         ``max_trials`` trials, until completion criterion is reached. This is the
         'main' method of a ``Scheduler``.
@@ -717,13 +730,12 @@
         # schedule new trials and poll existing ones in a loop.
         while not self.completion_criterion() and len(trials) - n_existing < max_trials:
             if self.should_abort():
-                res = self.report_results()
+                self._record_optimization_complete_message()
                 self._record_run_trials_status(
                     num_preexisting_trials=n_existing, status=RunTrialsStatus.ABORTED
                 )
-                # pyre-fixme[7]: Expected `Generator[Dict[str, typing.Any], None,
-                # None]` but got `Dict[str, typing.Any]`. T84274305
-                return res
+                yield self.report_results()
+                return
 
             # Run new trial evaluations until `run` returns `False`, which
             # means that there was a reason not to run more evaluations yet.
@@ -734,7 +746,6 @@
             remaining_to_run = max_trials + n_existing - len(self.experiment.trials)
 
             # Wait for trial evaluations to complete and process results.
-            # pyre-fixme[7]: T84274305, as above
             yield self.wait_for_completed_trials_and_report_results()
 
         # When done scheduling, wait for the remaining trials to finish running
@@ -746,22 +757,24 @@
         )
         while self.running_trials:
             if self.should_abort():
-                res = self.report_results()
+                self._record_optimization_complete_message()
                 self._record_run_trials_status(
                     num_preexisting_trials=n_existing, status=RunTrialsStatus.ABORTED
                 )
-                return res  # pyre-fixme[7]: T84274305, as above
+                yield self.report_results()
+                return
 
-            # pyre-fixme[7]: T84274305, as above
             yield self.wait_for_completed_trials_and_report_results()
 
+        self._record_optimization_complete_message()
         res = self.wait_for_completed_trials_and_report_results()
         # raise an error if the failure rate exceeds tolerance at the end of the sweep
         self.error_if_failure_rate_exceeded(force_check=True)
         self._record_run_trials_status(
             num_preexisting_trials=n_existing, status=RunTrialsStatus.SUCCESS
         )
-        return res  # pyre-fixme[7]: T84274305, as above
+        yield res
+        return
 
     def run_n_trials(
         self, max_trials: int, timeout_hours: Optional[int] = None
@@ -1312,6 +1325,19 @@ def _record_run_trials_status(
         ] = new_trials
         self._append_to_experiment_properties(to_append=to_append)
 
+    def _record_optimization_complete_message(self) -> None:
+        """Adds a simple optimization completion message to this scheduler's markdown
+        messages.
+        """
+        self.markdown_messages[
+            "optimization_completion"
+        ] = OPTIMIZATION_COMPLETION_MSG.format(
+            num_trials=len(self.experiment.trials),
+            experiment_name=self.experiment.name
+            if self.experiment._name is not None
+            else "unnamed",
+        )
+
     def _append_to_experiment_properties(self, to_append: Dict[str, Any]) -> None:
         """Appends to list fields in experiment properties based on ``to_append``
         input dict of form {property_name: value_to_append}.
diff --git a/ax/service/tests/test_scheduler.py b/ax/service/tests/test_scheduler.py
index e110494b679..6cc1ad4ffcc 100644
--- a/ax/service/tests/test_scheduler.py
+++ b/ax/service/tests/test_scheduler.py
@@ -615,7 +615,7 @@ def test_run_trials_and_yield_results(self):
         # as many times as `total_trials` and yielding from generator after
         # obtaining each new result.
         res_list = list(scheduler.run_trials_and_yield_results(max_trials=total_trials))
-        self.assertEqual(len(res_list), total_trials)
+        self.assertEqual(len(res_list), total_trials + 1)
         self.assertIsInstance(res_list, list)
         self.assertEqual(len(res_list[0]["trials_completed_so_far"]), 1)
         self.assertEqual(len(res_list[1]["trials_completed_so_far"]), 2)
@@ -650,7 +650,7 @@ def should_stop_trials_early(self, trial_indices: Set[int]):
         )
         # Two steps complete the experiment given parallelism.
         expected_num_polls = 2
-        self.assertEqual(len(res_list), expected_num_polls)
+        self.assertEqual(len(res_list), expected_num_polls + 1)
         self.assertIsInstance(res_list, list)
         # Both trials in first batch of parallelism will be early stopped
         self.assertEqual(len(res_list[0]["trials_early_stopped_so_far"]), 2)
@@ -712,7 +712,7 @@ def poll_trial_status(self):
             scheduler.run_trials_and_yield_results(max_trials=total_trials)
         )
         expected_num_steps = 2
-        self.assertEqual(len(res_list), expected_num_steps)
+        self.assertEqual(len(res_list), expected_num_steps + 1)
         # Trial #1 early stopped in first step
         self.assertEqual(res_list[0]["trials_early_stopped_so_far"], {1})
         # All trials completed by end of second step
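
Usage note (illustrative, not part of the patch): the sketch below shows how a
caller might consume the new `markdown_messages` field and the changed generator
behavior of `run_trials_and_yield_results`. `MyScheduler`, `experiment`, and
`generation_strategy` are hypothetical stand-ins, since `Scheduler` is abstract
and a real caller supplies their own subclass, experiment, and generation
strategy.

    from ax.service.scheduler import SchedulerOptions

    # `MyScheduler` is a hypothetical concrete `Scheduler` subclass; `experiment`
    # and `generation_strategy` are assumed to be configured elsewhere.
    scheduler = MyScheduler(
        experiment=experiment,
        generation_strategy=generation_strategy,
        options=SchedulerOptions(),
    )

    # Seeded at instantiation with the generation-strategy message.
    print(scheduler.markdown_messages["generation_strategy"])

    # The method now yields one results dict per poll, plus a final one at the
    # end of the sweep, instead of returning the final result via the
    # generator's `StopIteration` value.
    for results in scheduler.run_trials_and_yield_results(max_trials=10):
        print(sorted(results))

    # Recorded by `_record_optimization_complete_message` once the sweep
    # completes or aborts.
    print(scheduler.markdown_messages["optimization_completion"])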