
Improve MAPElites performance using torch_scatter #93

Open
JakeF-Bitweave opened this issue Oct 19, 2023 · 1 comment

JakeF-Bitweave commented Oct 19, 2023

Would you be interested in contributions to re-work MAPElites to use torch_scatter rather than the vmapped extended_population x feature_grid operation?

The general gist is:

  • Quantize the features and round them to longs (one grid-cell bin index per feature dimension)
  • Get the index of the elite for each hypervolume using scatter_max (which returns argmaxes)
  • Map the evals and values from the extended population back to the population using those indexes (a minimal sketch follows this list)
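For illustration, here is a minimal, self-contained sketch of those three steps on a toy one-dimensional feature space. The names and sizes below are made up for this example and are not part of the evotorch API; the full comparison against evotorch's MAPElites is in the next comment.

import torch
from torch_scatter import scatter_max

n_extended = 8                       # size of the extended population
lower, upper, n_bins = -1.0, 1.0, 4  # one feature dimension with 4 bins

fitnesses = torch.randn(n_extended)
features = torch.empty(n_extended).uniform_(lower, upper)

# 1. Quantize the features and round them to longs (one bin index per solution).
scale = n_bins / (upper - lower)
cell_index = ((features - lower) * scale).long().clamp_(0, n_bins - 1)

# 2. scatter_max returns, per cell, the best fitness and the argmax into the
#    extended population; cells with no members get argmax == n_extended.
_, argbest = scatter_max(fitnesses, cell_index, dim_size=n_bins)
filled = argbest < n_extended

# 3. Map the winners from the extended population back onto the cell-indexed population.
population_fitness = torch.zeros(n_bins)
population_fitness[filled] = fitnesses[argbest[filled]]
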
JakeF-Bitweave (Author) commented:

Runnable comparison:

import math
import time
from typing import NamedTuple, List, Iterable

import torch
from torch_scatter import scatter_max, scatter_min
from evotorch import Problem
from evotorch.algorithms import MAPElites, SearchAlgorithm
from evotorch.algorithms.ga import ExtendedPopulationMixin
from evotorch.algorithms.searchalgorithm import SinglePopulationAlgorithmMixin
from evotorch.operators import GaussianMutation, SimulatedBinaryCrossOver


class FeatureGrid(NamedTuple):
    lower_bounds: List[float]
    upper_bounds: List[float]
    bins: List[int]


class MAPElitesScatter(MAPElites):
    def __init__(
        self,
        problem: Problem,
        *,
        operators: Iterable,
        feature_grid: FeatureGrid,
    ):
        problem.ensure_single_objective()
        problem.ensure_numeric()
        SearchAlgorithm.__init__(self, problem)
        self._sense = self._problem.senses[0]
        self._feature_grid = feature_grid
        self._popsize = math.prod(feature_grid.bins)
        self._population = problem.generate_batch(self._popsize)
        self._filled = torch.zeros(self._popsize, dtype=torch.bool, device=self._population.device)
        self._scatter_best = scatter_max if self._sense == "max" else scatter_min
        ExtendedPopulationMixin.__init__(
            self,
            re_evaluate=True,
            re_evaluate_parents_first=None,
            operators=operators,
            allow_empty_operators_list=False,
        )
        SinglePopulationAlgorithmMixin.__init__(self)

    def _step(self):
        # Form an extended population from the parents and from the children
        extended_population = self._make_extended_population(split=False)
        extended_pop_size = extended_population.eval_shape[0]

        all_evals = extended_population.evals.as_subclass(torch.Tensor)
        all_values = extended_population.values.as_subclass(torch.Tensor)
        all_fitnesses = all_evals[:, 0]
        feats = all_evals[:, 1:]
        device = all_evals.device

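        # Compute a flat grid-cell ("hypervolume") index for every member of the
        # extended population by quantizing its features dimension by dimension.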
        hypervolume_index = torch.zeros(extended_pop_size, device=device, dtype=torch.long)
        widths = []
        for i, (lb, ub, n_bins) in enumerate(zip(*self._feature_grid)):
            diff = ub - lb
            const = n_bins / diff
            min_ = const * lb
            max_ = (const * ub) - 1

            # Work on a copy so that the scaling below does not mutate all_evals
            # (and therefore the extended population's evals) in place.
            feat = feats[:, i] * const
            feat = torch.clamp_min(feat, min_)
            feat = torch.clamp_max(feat, max_)
            feat -= min_

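            # The i-th feature contributes to the flat index with a stride equal to
            # the product of the bin counts of the previous feature dimensions.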
            hypervolume_index += (feat.long() * math.prod(widths))
            widths.append(n_bins)

        # Find the best extended-population member for each hypervolume.
        # torch_scatter sets the argmax of cells with no members to extended_pop_size,
        # which is filtered out below.
        _, argbest = self._scatter_best(all_fitnesses, hypervolume_index)

        # Filter hypervolumes that had no members
        all_index = argbest[argbest < extended_pop_size]
        index = torch.argwhere(argbest < extended_pop_size)[:, 0]

        # Build empty output
        values = torch.zeros((self._popsize, all_values.shape[1]), device=device, dtype=all_values.dtype)
        evals = torch.zeros((self._popsize, all_evals.shape[1]), device=device, dtype=all_evals.dtype)
        suitable = torch.zeros(self._popsize, device=device, dtype=torch.bool)

        # Map the members from the extended population to the output
        values[index] = all_values[all_index]
        evals[index] = all_evals[all_index]
        suitable[index] = True

        # Place the most suitable decision values and evaluation results into the current population.
        self._population.access_values(keep_evals=True)[:] = values
        self._population.access_evals()[:] = evals

        # If there was a suitable solution for the i-th cell, fill[i] is to be set as True.
        self._filled[:] = suitable


def kursawe(x: torch.Tensor) -> torch.Tensor:
    f1 = torch.sum(
        -10 * torch.exp(
            -0.2 * torch.sqrt(x[:, 0:2] ** 2.0 + x[:, 1:3] ** 2.0)
        ),
        dim=-1,
    )
    f2 = torch.sum(
        (torch.abs(x) ** 0.8) + (5 * torch.sin(x ** 3)),
        dim=-1,
    )
    fitnesses = torch.stack([f1 + f2, f1, f2], dim=-1)
    return fitnesses


if __name__ == "__main__":
    tensor_feature_grid = MAPElites.make_feature_grid(
        lower_bounds=[-20, -14],
        upper_bounds=[-10, 4],
        num_bins=50,
        dtype="float32",
    )

    for clazz, feature_grid in [
        (MAPElitesScatter, FeatureGrid([-20, -14], [-10, 4], [50, 50])),
        (MAPElites, tensor_feature_grid),
    ]:
        problem = Problem(
            "min",
            kursawe,
            solution_length=3,
            eval_data_length=2,
            bounds=(-5.0, 5.0),
            vectorized=True,
        )
        searcher = clazz(
            problem,
            feature_grid=feature_grid,
            operators=[
                SimulatedBinaryCrossOver(problem, tournament_size=4, cross_over_rate=1.0, eta=8),
                GaussianMutation(problem, stdev=0.03),
            ],
        )
        start = time.time()
        searcher.run(100)
        print("Final status:\n", searcher.status)
        print("Impl: ", clazz)
        print("Time spent (secs): ", time.time() - start)
        print("Filled hypervolumes: ", searcher.filled.sum())

out:

[2023-10-19 15:23:41] INFO     <46962> evotorch.core: Instance of `Problem` (id:6097491664) -- The `dtype` for the problem's decision variables is set as torch.float32
[2023-10-19 15:23:41] INFO     <46962> evotorch.core: Instance of `Problem` (id:6097491664) -- `eval_dtype` (the dtype of the fitnesses and evaluation data) is set as torch.float32
[2023-10-19 15:23:41] INFO     <46962> evotorch.core: Instance of `Problem` (id:6097491664) -- The `device` of the problem is set as cpu
[2023-10-19 15:23:41] INFO     <46962> evotorch.core: Instance of `Problem` (id:6097491664) -- The number of actors that will be allocated for parallelized evaluation is 0
Final status:
 <LazyStatusDict
    pop_best = <not yet computed>
    pop_best_eval = <not yet computed>
    mean_eval = <not yet computed>
    median_eval = <not yet computed>
    iter = 100
    best = <Solution values=tensor([-1.1392, -1.1283, -1.1402]), evals=tensor([-26.1042, -14.5122, -11.5920])>
    worst = <Solution values=tensor([ 4.7636, -4.5227, -4.6175]), evals=tensor([18.8853, -5.4335, 24.3187])>
    best_eval = -26.104228973388672
    worst_eval = 18.88526725769043
>
Impl:  <class '__main__.MAPElitesScatter'>
Time spent (secs):  0.2477729320526123
Filled hypervolumes:  ReadOnlyTensor(1562)
[2023-10-19 15:23:41] INFO     <46962> evotorch.core: Instance of `Problem` (id:5351110928) -- The `dtype` for the problem's decision variables is set as torch.float32
[2023-10-19 15:23:41] INFO     <46962> evotorch.core: Instance of `Problem` (id:5351110928) -- `eval_dtype` (the dtype of the fitnesses and evaluation data) is set as torch.float32
[2023-10-19 15:23:41] INFO     <46962> evotorch.core: Instance of `Problem` (id:5351110928) -- The `device` of the problem is set as cpu
[2023-10-19 15:23:41] INFO     <46962> evotorch.core: Instance of `Problem` (id:5351110928) -- The number of actors that will be allocated for parallelized evaluation is 0
Final status:
 <LazyStatusDict
    pop_best = <not yet computed>
    pop_best_eval = <not yet computed>
    mean_eval = <not yet computed>
    median_eval = <not yet computed>
    iter = 100
    best = <Solution values=tensor([-1.1402, -1.1233, -1.1448]), evals=tensor([-26.1034, -14.5165, -11.5869])>
    worst = <Solution values=tensor([ 4.4850, -4.8089, -4.8177]), evals=tensor([18.5221, -5.2474, 23.7694])>
    best_eval = -26.103425979614258
    worst_eval = 18.5220890045166
>
Impl:  <class 'evotorch.algorithms.mapelites.MAPElites'>
Time spent (secs):  8.833541870117188
Filled hypervolumes:  ReadOnlyTensor(1522)
