Add sweep messages to ParamSweepScheduler and set up for searcher explanation in outputs

Summary: As titled

Reviewed By: bernardbeckerman

Differential Revision: D29245864

fbshipit-source-id: 1d1614535ded544892d72c05b8ef69f81e2fddac
lena-kashtelyan authored and facebook-github-bot committed Jun 23, 2021
1 parent 2317218 commit 06e93dc
Showing 2 changed files with 42 additions and 13 deletions.
49 changes: 39 additions & 10 deletions ax/service/scheduler.py
@@ -57,6 +57,11 @@
If this functionality is desired, specify the method in the
scheduler subclass.
"""
GS_TYPE_MSG = "This optimization run uses a '{gs_name}' generation strategy."
OPTIMIZATION_COMPLETION_MSG = """Optimization completed with total of {num_trials}
trials attached to the underlying Ax experiment '{experiment_name}'.
"""


# Wait time b/w polls will not exceed 5 mins.
MAX_SECONDS_BETWEEN_POLLS = 300
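
For orientation, here is how the two new module-level templates render once formatted. The template text is copied from the hunk above; the sample values ('Sobol+GPEI', 20, 'hartmann6_experiment') are illustrative placeholders, not part of this change:

GS_TYPE_MSG = "This optimization run uses a '{gs_name}' generation strategy."
OPTIMIZATION_COMPLETION_MSG = """Optimization completed with total of {num_trials}
trials attached to the underlying Ax experiment '{experiment_name}'.
"""

# Prints: This optimization run uses a 'Sobol+GPEI' generation strategy.
print(GS_TYPE_MSG.format(gs_name="Sobol+GPEI"))

# Prints the completion message with the trial count and experiment name filled in.
print(OPTIMIZATION_COMPLETION_MSG.format(num_trials=20, experiment_name="hartmann6_experiment"))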
@@ -215,6 +220,11 @@ class Scheduler(WithDBSettingsBase, ABC):
generation_strategy: GenerationStrategy
options: SchedulerOptions
logger: LoggerAdapter
# Mapping of form {short string identifier -> message to show in reported
# results}. This is a mapping and not a list to allow for changing of
# some sweep messages throughout the course of the optimization (e.g. progress
# report of the optimization).
markdown_messages: Dict[str, str]

# Number of trials that existed on the scheduler's experiment before
# the scheduler instantiation with that experiment.
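
The comment above motivates keying messages by a short identifier rather than accumulating them in a list. A minimal sketch of that design choice, using hypothetical keys and wording (only "generation_strategy" is a key actually set in this diff):

from typing import Dict

# Keyed messages can be overwritten in place, so a report always shows the
# latest text for a given topic instead of every intermediate version.
markdown_messages: Dict[str, str] = {
    "generation_strategy": "This optimization run uses a 'Sobol+GPEI' generation strategy.",
}
markdown_messages["progress"] = "Completed 5 of 20 trials."   # hypothetical progress key
markdown_messages["progress"] = "Completed 12 of 20 trials."  # replaces the earlier text
print(markdown_messages["progress"])  # -> Completed 12 of 20 trials.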
@@ -280,6 +290,9 @@ def __init__(
# when trials are not generated for the same reason multiple times in
# a row.
self._log_next_no_trials_reason = True
self.markdown_messages = {
"generation_strategy": GS_TYPE_MSG.format(gs_name=generation_strategy.name)
}

@classmethod
def get_default_db_settings(cls) -> DBSettings:
@@ -686,7 +699,7 @@ def summarize_final_result(self) -> OptimizationResult:

def run_trials_and_yield_results(
self, max_trials: int, timeout_hours: Optional[int] = None
) -> Generator[Dict[str, Any], None, Dict[str, Any]]:
) -> Generator[Dict[str, Any], None, None]:
"""Make continuous calls to `run` and `process_results` to run up to
``max_trials`` trials, until completion criterion is reached. This is the 'main'
method of a ``Scheduler``.
@@ -717,13 +730,15 @@ def run_trials_and_yield_results(
# schedule new trials and poll existing ones in a loop.
while not self.completion_criterion() and len(trials) - n_existing < max_trials:
if self.should_abort():
res = self.report_results()
self._record_optimization_complete_message()
self._record_run_trials_status(
num_preexisting_trials=n_existing, status=RunTrialsStatus.ABORTED
)
# pyre-fixme[7]: Expected `Generator[Dict[str, typing.Any], None,
# None]` but got `Dict[str, typing.Any]`. T84274305
return res
self._record_run_trials_status(
num_preexisting_trials=n_existing, status=RunTrialsStatus.ABORTED
)
yield self.report_results()[1]
return

# Run new trial evaluations until `run` returns `False`, which
# means that there was a reason not to run more evaluations yet.
@@ -734,7 +749,6 @@
remaining_to_run = max_trials + n_existing - len(self.experiment.trials)

# Wait for trial evaluations to complete and process results.
# pyre-fixme[7]: T84274305, as above
yield self.wait_for_completed_trials_and_report_results()

# When done scheduling, wait for the remaining trials to finish running
@@ -746,22 +760,24 @@
)
while self.running_trials:
if self.should_abort():
res = self.report_results()
self._record_optimization_complete_message()
self._record_run_trials_status(
num_preexisting_trials=n_existing, status=RunTrialsStatus.ABORTED
)
return res # pyre-fixme[7]: T84274305, as above
yield self.report_results()[1]
return

# pyre-fixme[7]: T84274305, as above
yield self.wait_for_completed_trials_and_report_results()

self._record_optimization_complete_message()
res = self.wait_for_completed_trials_and_report_results()
# raise an error if the failure rate exceeds tolerance at the end of the sweep
self.error_if_failure_rate_exceeded(force_check=True)
self._record_run_trials_status(
num_preexisting_trials=n_existing, status=RunTrialsStatus.SUCCESS
)
return res # pyre-fixme[7]: T84274305, as above
yield res
return

def run_n_trials(
self, max_trials: int, timeout_hours: Optional[int] = None
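
The hunks above change the return annotation from Generator[Dict[str, Any], None, Dict[str, Any]] to Generator[Dict[str, Any], None, None] and replace "return res" with a final yield. A self-contained sketch (unrelated to Ax internals) of why that matters to callers iterating the generator: a generator's return value is carried on StopIteration and dropped by list(), whereas a final yield is captured as an ordinary element.

from typing import Generator

def returns_final_report() -> Generator[int, None, str]:
    yield 1
    yield 2
    return "final report"  # stored on StopIteration.value; list() never sees it

def yields_final_report() -> Generator[int, None, None]:
    yield 1
    yield 2
    yield 3  # the final report is just the last yielded element

print(list(returns_final_report()))  # [1, 2] -- the returned value is lost
print(list(yields_final_report()))   # [1, 2, 3] -- the final element is captured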
@@ -1312,6 +1328,19 @@ def _record_run_trials_status(
] = new_trials
self._append_to_experiment_properties(to_append=to_append)

def _record_optimization_complete_message(self) -> None:
"""Adds a simple optimization completion message to this scheduler's markdown
messages.
"""
self.markdown_messages[
"optimization_completion"
] = OPTIMIZATION_COMPLETION_MSG.format(
num_trials=len(self.experiment.trials),
experiment_name=self.experiment.name
if self.experiment._name is not None
else "unnamed",
)

def _append_to_experiment_properties(self, to_append: Dict[str, Any]) -> None:
"""Appends to list fields in experiment properties based on ``to_append``
input dict of form {property_name: value_to_append}.
6 changes: 3 additions & 3 deletions ax/service/tests/test_scheduler.py
@@ -615,7 +615,7 @@ def test_run_trials_and_yield_results(self):
# as many times as `total_trials` and yielding from generator after
# obtaining each new result.
res_list = list(scheduler.run_trials_and_yield_results(max_trials=total_trials))
self.assertEqual(len(res_list), total_trials)
self.assertEqual(len(res_list), total_trials + 1)
self.assertIsInstance(res_list, list)
self.assertEqual(len(res_list[0]["trials_completed_so_far"]), 1)
self.assertEqual(len(res_list[1]["trials_completed_so_far"]), 2)
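
The "+ 1" adjustments in this file follow from the generator change in scheduler.py: the final, post-completion report is now yielded as a trailing element rather than returned, so list() picks it up. A toy stand-in (not the real Scheduler) that mirrors the expected list shape, with the field name taken from the assertions above:

from typing import Dict, Generator

def fake_run_trials_and_yield_results(max_trials: int) -> Generator[Dict[str, set], None, None]:
    completed: set = set()
    for i in range(max_trials):          # one report per completed trial in this toy
        completed.add(i)
        yield {"trials_completed_so_far": set(completed)}
    yield {"trials_completed_so_far": set(completed)}  # trailing final report

res_list = list(fake_run_trials_and_yield_results(max_trials=3))
assert len(res_list) == 3 + 1            # mirrors the `total_trials + 1` expectation above
assert len(res_list[0]["trials_completed_so_far"]) == 1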
@@ -650,7 +650,7 @@ def should_stop_trials_early(self, trial_indices: Set[int]):
)
# Two steps complete the experiment given parallelism.
expected_num_polls = 2
self.assertEqual(len(res_list), expected_num_polls)
self.assertEqual(len(res_list), expected_num_polls + 1)
self.assertIsInstance(res_list, list)
# Both trials in first batch of parallelism will be early stopped
self.assertEqual(len(res_list[0]["trials_early_stopped_so_far"]), 2)
@@ -712,7 +712,7 @@ def poll_trial_status(self):
scheduler.run_trials_and_yield_results(max_trials=total_trials)
)
expected_num_steps = 2
self.assertEqual(len(res_list), expected_num_steps)
self.assertEqual(len(res_list), expected_num_steps + 1)
# Trial #1 early stopped in first step
self.assertEqual(res_list[0]["trials_early_stopped_so_far"], {1})
# All trials completed by end of second step