Skip to content

Commit

Permalink
make reference point optional (#601)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #601

This diff does two things: 1) if a user creates a `MultiObjectiveTorchModelbridge` and calls `gen` without specifying objective thresholds, then we infer the objective thresholds within the `MultiObjectiveBotorchModel`, generate candidates, and return the inferred objective thresholds in `gen_metadata`. 2) it adds an `infer_objective_thresholds` method to the `MultiObjectiveTorchModelbridge`, which can be used to infer objective thresholds without generating candidates.

This refactors the base `Modelbridge.gen` and `ArrayModelbridge._gen` methods to apply transformations within a utility function.

Note that this method returns ObservationData. If the user wants to plot outcomes with objective thresholds, the user would have to create the ObjectiveThresholds and set the objective thresholds on the optimization config.

Reviewed By: Balandat

Differential Revision: D28163744

fbshipit-source-id: c74908290bf6b7162771c0768e06bf8181fd4185
  • Loading branch information
sdaulton authored and facebook-github-bot committed Jun 19, 2021
1 parent d969efc commit 3d20dc0
Show file tree
Hide file tree
Showing 9 changed files with 748 additions and 64 deletions.
88 changes: 65 additions & 23 deletions ax/modelbridge/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

import numpy as np
Expand Down Expand Up @@ -35,6 +36,18 @@
FIT_MODEL_ERROR = "Model must be fit before {action}."


@dataclass
class ArrayModelGenArgs:
    """Transformed, array-space arguments for an array-based model's `gen` call.

    Produced by `ArrayModelBridge._get_transformed_model_gen_args` and consumed
    by `_gen` (and, per the TODO there, eventually by `_model_gen` directly).
    Bundling these into one container lets `gen`-adjacent code paths (e.g.
    inferring objective thresholds without generating) reuse the same
    transformation logic.
    """

    # Digest of the transformed search space (bounds, target fidelities, etc.).
    search_space_digest: SearchSpaceDigest
    # Per-outcome objective weights in array form.
    objective_weights: np.ndarray
    # (A, b) arrays encoding outcome constraints, or None if unconstrained.
    outcome_constraints: Optional[Tuple[np.ndarray, np.ndarray]]
    # (A, b) arrays encoding linear parameter constraints, or None.
    linear_constraints: Optional[Tuple[np.ndarray, np.ndarray]]
    # Map from parameter index to fixed value, or None if nothing is fixed.
    fixed_features: Optional[Dict[int, float]]
    # One array of pending points per outcome, or None if none are pending.
    pending_observations: Optional[List[np.ndarray]]
    # Callback that rounds/untransforms candidate arrays (built from the
    # modelbridge's transforms via `transform_callback`).
    rounding_func: Callable[[np.ndarray], np.ndarray]
    # Additional model-specific kwargs splatted into `_model_gen`.
    extra_model_gen_kwargs: Dict[str, Any]


# pyre-fixme[13]: Attribute `model` is never initialized.
# pyre-fixme[13]: Attribute `outcomes` is never initialized.
# pyre-fixme[13]: Attribute `parameters` is never initialized.
Expand Down Expand Up @@ -177,25 +190,14 @@ def _get_extra_model_gen_kwargs(
) -> Dict[str, Any]:
return {}

def _gen(
def _get_transformed_model_gen_args(
self,
n: int,
search_space: SearchSpace,
pending_observations: Dict[str, List[ObservationFeatures]],
fixed_features: ObservationFeatures,
model_gen_options: Optional[TConfig] = None,
optimization_config: Optional[OptimizationConfig] = None,
) -> Tuple[
List[ObservationFeatures],
List[float],
Optional[ObservationFeatures],
TGenMetadata,
]:
"""Generate new candidates according to search_space and
optimization_config.
The outcome constraints should be transformed to no longer be relative.
"""
) -> ArrayModelGenArgs:
# Validation
if not self.parameters: # pragma: no cover
raise ValueError(FIT_MODEL_ERROR.format(action="_gen"))
Expand Down Expand Up @@ -228,30 +230,70 @@ def _gen(
pending_array = pending_observations_as_array(
pending_observations, self.outcomes, self.parameters
)
# Generate the candidates
X, w, gen_metadata, candidate_metadata = self._model_gen(
n=n,
bounds=search_space_digest.bounds,
return ArrayModelGenArgs(
search_space_digest=search_space_digest,
objective_weights=objective_weights,
outcome_constraints=outcome_constraints,
linear_constraints=linear_constraints,
fixed_features=fixed_features_dict,
pending_observations=pending_array,
model_gen_options=model_gen_options,
rounding_func=transform_callback(self.parameters, self.transforms),
extra_model_gen_kwargs=extra_model_gen_kwargs,
)

def _gen(
self,
n: int,
search_space: SearchSpace,
pending_observations: Dict[str, List[ObservationFeatures]],
fixed_features: ObservationFeatures,
model_gen_options: Optional[TConfig] = None,
optimization_config: Optional[OptimizationConfig] = None,
) -> Tuple[
List[ObservationFeatures],
List[float],
Optional[ObservationFeatures],
TGenMetadata,
]:
"""Generate new candidates according to search_space and
optimization_config.
The outcome constraints should be transformed to no longer be relative.
"""
array_model_gen_args = self._get_transformed_model_gen_args(
search_space=search_space,
pending_observations=pending_observations,
fixed_features=fixed_features,
model_gen_options=model_gen_options,
optimization_config=optimization_config,
)

# Generate the candidates
search_space_digest = array_model_gen_args.search_space_digest
# TODO: pass array_model_gen_args to _model_gen
X, w, gen_metadata, candidate_metadata = self._model_gen(
n=n,
bounds=search_space_digest.bounds,
objective_weights=array_model_gen_args.objective_weights,
outcome_constraints=array_model_gen_args.outcome_constraints,
linear_constraints=array_model_gen_args.linear_constraints,
fixed_features=array_model_gen_args.fixed_features,
pending_observations=array_model_gen_args.pending_observations,
model_gen_options=model_gen_options,
rounding_func=array_model_gen_args.rounding_func,
target_fidelities=search_space_digest.target_fidelities,
**extra_model_gen_kwargs,
**array_model_gen_args.extra_model_gen_kwargs,
)
# Transform array to observations
observation_features = parse_observation_features(
X=X, param_names=self.parameters, candidate_metadata=candidate_metadata
)
xbest = self._model_best_point(
bounds=search_space_digest.bounds,
objective_weights=objective_weights,
outcome_constraints=outcome_constraints,
linear_constraints=linear_constraints,
fixed_features=fixed_features_dict,
objective_weights=array_model_gen_args.objective_weights,
outcome_constraints=array_model_gen_args.outcome_constraints,
linear_constraints=array_model_gen_args.linear_constraints,
fixed_features=array_model_gen_args.fixed_features,
model_gen_options=model_gen_options,
target_fidelities=search_space_digest.target_fidelities,
)
Expand Down
89 changes: 59 additions & 30 deletions ax/modelbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from abc import ABC
from collections import OrderedDict
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, MutableMapping, Optional, Set, Tuple, Type

from ax.core.arm import Arm
Expand Down Expand Up @@ -43,6 +44,14 @@
logger = get_logger(__name__)


@dataclass
class BaseGenArgs:
    """Transformed inputs for `ModelBridge._gen`.

    Produced by `ModelBridge._get_transformed_gen_args`, which fills in
    defaults (empty pending observations / fixed features), clones the
    optimization config, and applies each registered transform in order.
    """

    # Search space after all transforms have been applied.
    search_space: SearchSpace
    # Transformed (non-relative) optimization config.
    optimization_config: OptimizationConfig
    # Metric name -> transformed pending observation features.
    pending_observations: Dict[str, List[ObservationFeatures]]
    # Transformed features to hold fixed during generation.
    fixed_features: ObservationFeatures

class ModelBridge(ABC):
"""The main object for using models in Ax.
Expand Down Expand Up @@ -576,40 +585,18 @@ def _update(
"""
raise NotImplementedError # pragma: no cover

def gen(
def _get_transformed_gen_args(
self,
n: int,
search_space: Optional[SearchSpace] = None,
search_space: SearchSpace,
optimization_config: Optional[OptimizationConfig] = None,
pending_observations: Optional[Dict[str, List[ObservationFeatures]]] = None,
fixed_features: Optional[ObservationFeatures] = None,
model_gen_options: Optional[TConfig] = None,
) -> GeneratorRun:
"""
Args:
n: Number of points to generate
search_space: Search space
optimization_config: Optimization config
pending_observations: A map from metric name to pending
observations for that metric.
fixed_features: An ObservationFeatures object containing any
features that should be fixed at specified values during
generation.
model_gen_options: A config dictionary that is passed along to the
model.
"""
t_gen_start = time.time()
) -> BaseGenArgs:
if pending_observations is None:
pending_observations = {}
if fixed_features is None:
fixed_features = ObservationFeatures({})

# Get modifiable versions
if search_space is None:
search_space = self._model_space
orig_search_space = search_space
search_space = search_space.clone()

if optimization_config is None:
optimization_config = (
# pyre-fixme[16]: `Optional` has no attribute `clone`.
Expand All @@ -636,14 +623,55 @@ def gen(
for metric, po in pending_observations.items():
pending_observations[metric] = t.transform_observation_features(po)
fixed_features = t.transform_observation_features([fixed_features])[0]
return BaseGenArgs(
search_space=search_space,
optimization_config=optimization_config,
pending_observations=pending_observations,
fixed_features=fixed_features,
)

# Apply terminal transform and gen
observation_features, weights, best_obsf, gen_metadata = self._gen(
n=n,
def gen(
self,
n: int,
search_space: Optional[SearchSpace] = None,
optimization_config: Optional[OptimizationConfig] = None,
pending_observations: Optional[Dict[str, List[ObservationFeatures]]] = None,
fixed_features: Optional[ObservationFeatures] = None,
model_gen_options: Optional[TConfig] = None,
) -> GeneratorRun:
"""
Args:
n: Number of points to generate
search_space: Search space
optimization_config: Optimization config
pending_observations: A map from metric name to pending
observations for that metric.
fixed_features: An ObservationFeatures object containing any
features that should be fixed at specified values during
generation.
model_gen_options: A config dictionary that is passed along to the
model.
"""
t_gen_start = time.time()
# Get modifiable versions
if search_space is None:
search_space = self._model_space
orig_search_space = search_space
search_space = search_space.clone()
base_gen_args = self._get_transformed_gen_args(
search_space=search_space,
optimization_config=optimization_config,
pending_observations=pending_observations,
fixed_features=fixed_features,
)

# Apply terminal transform and gen
observation_features, weights, best_obsf, gen_metadata = self._gen(
n=n,
search_space=base_gen_args.search_space,
optimization_config=base_gen_args.optimization_config,
pending_observations=base_gen_args.pending_observations,
fixed_features=base_gen_args.fixed_features,
model_gen_options=model_gen_options,
)
# Apply reverse transforms
Expand Down Expand Up @@ -692,11 +720,12 @@ def gen(
immutable = getattr(
self, "_experiment_has_immutable_search_space_and_opt_config", False
)
optimization_config = None if immutable else base_gen_args.optimization_config
gr = GeneratorRun(
arms=arms,
weights=weights,
optimization_config=None if immutable else optimization_config,
search_space=None if immutable else search_space,
optimization_config=optimization_config,
search_space=None if immutable else base_gen_args.search_space,
model_predictions=model_predictions,
best_arm_predictions=None
if best_arm is None
Expand Down
Loading

0 comments on commit 3d20dc0

Please sign in to comment.