Skip to content

Commit

Permalink
Remove MOOAcquisition
Browse files Browse the repository at this point in the history
Summary: There is no longer need for separate acquisition wrapper class for MOO

Reviewed By: Balandat

Differential Revision: D28207042

fbshipit-source-id: 830745ebcb18b2cc8c4e0ea305175395b236adbd
  • Loading branch information
lena-kashtelyan authored and facebook-github-bot committed May 21, 2021
1 parent 715e2ef commit 5cd1bd6
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 398 deletions.
1 change: 1 addition & 0 deletions ax/models/torch/botorch_modular/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,5 +307,6 @@ def get_botorch_objective(
botorch_acqf_class, AnalyticAcquisitionFunction
),
outcome_constraints=outcome_constraints,
objective_thresholds=objective_thresholds,
X_observed=X_observed,
)
14 changes: 13 additions & 1 deletion ax/models/torch/botorch_modular/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,20 @@ def _instantiate_acquisition(
pending_observations: Optional[List[Tensor]] = None,
acq_options: Optional[Dict[str, Any]] = None,
) -> Acquisition:
"""Set an BoTorch acquisition function class for this model if needed and
instantiate it.
Returns:
BoTorch ``AcquisitionFunction`` instance.
"""
if not self._botorch_acqf_class:
self._botorch_acqf_class = choose_botorch_acqf_class()
self._botorch_acqf_class = choose_botorch_acqf_class(
objective_thresholds=objective_thresholds,
outcome_constraints=outcome_constraints,
linear_constraints=linear_constraints,
fixed_features=fixed_features,
pending_observations=pending_observations,
)

return self.acquisition_class(
surrogate=self.surrogate,
Expand Down
135 changes: 0 additions & 135 deletions ax/models/torch/botorch_modular/moo_acquisition.py

This file was deleted.

20 changes: 15 additions & 5 deletions ax/models/torch/botorch_modular/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Tuple, Type
from typing import List, Dict, Optional, Tuple, Type

import torch
from ax.core.search_space import SearchSpaceDigest
Expand All @@ -13,6 +13,9 @@
from ax.utils.common.typeutils import checked_cast
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement
from botorch.acquisition.multi_objective.monte_carlo import (
qExpectedHypervolumeImprovement,
)
from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP
from botorch.models.gp_regression_fidelity import (
FixedNoiseMultiFidelityGP,
Expand Down Expand Up @@ -102,11 +105,18 @@ def choose_model_class(
return FixedNoiseGP # Known observation noise.


def choose_botorch_acqf_class() -> Type[AcquisitionFunction]:
def choose_botorch_acqf_class(
pending_observations: Optional[List[Tensor]] = None,
outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None,
linear_constraints: Optional[Tuple[Tensor, Tensor]] = None,
fixed_features: Optional[Dict[int, float]] = None,
objective_thresholds: Optional[Tensor] = None,
) -> Type[AcquisitionFunction]:
"""Chooses a BoTorch `AcquisitionFunction` class."""
# NOTE: In the future, this dispatch function could leverage any
# of the attributes of `BoTorchModel` or kwargs passed to
# `BoTorchModel.gen` to intelligently select acquisition function.
if objective_thresholds is not None:
# TODO: Use new qNEHVI just added to BoTorch once we have `construct_inputs`
# for it in BoTorch.
return qExpectedHypervolumeImprovement
return qNoisyExpectedImprovement


Expand Down
Loading

0 comments on commit 5cd1bd6

Please sign in to comment.