tooling to help suggest the best locations for checkpoints (facebookresearch#136)

Summary:
Adds functions that suggest the best places to split the accumulation of activations. These split points give the boundaries of the `checkpoint_wrapper` blocks to insert in the model to limit its activation memory accumulation.

The suggested checkpoint locations are not perfect because:

1. they do not take into account the accumulation of gradients in the backward pass (which tends to reduce the need for checkpoints at the end of the model, i.e. the first checkpoints to be traversed in the backward pass)
2. they do not take into account code constraints such as "it's hard to split exactly there, let's split a bit further"

But they tend to give a good starting point.
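
For illustration, here is a minimal sketch of the intended workflow. It is hedged: the `memory_traces` attribute used to read back the recorded traces is an assumption (only `monitor` appears in this diff), and the model and parameter values are purely illustrative.

```python
import torch
import torchvision.models as models

from vissl.utils.layer_memory_tracking import (
    LayerwiseMemoryTracker,
    suggest_checkpoint_location,
)

# Record per-layer activation memory during one forward pass (GPU machine assumed)
model = models.resnet50().cuda()
tracker = LayerwiseMemoryTracker()
tracker.monitor(model)
model(torch.randn(size=(2, 3, 224, 224), device="cuda"))

# Ask for the best places to reset activation memory; the returned module
# names are the boundaries for the checkpoint_wrapper wrappings.
# NOTE: `tracker.memory_traces` is an assumed accessor for the traces.
suggestion = suggest_checkpoint_location(
    tracker.memory_traces, nb_checkpoints=2, num_skipped_layers=0
)
print(suggestion.max_memory, suggestion.split_modules)
```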

**Example**: I used this tooling to compute the best places to allocate checkpoints, with results such as this:

<img width="498" alt="Screenshot 2021-05-04 at 18 17 50" src="https://user-images.githubusercontent.com/7412790/117146564-58acb780-ad82-11eb-94a3-1b6be4a9997e.png">

As the size of the model decreases relative to the activations (the more we shard the model or increase the batch size), these suggestions tend toward the optimal configuration.

CC: min-xu-ai prigoyal

Pull Request resolved: fairinternal/ssl_scaling#136

Reviewed By: prigoyal

Differential Revision: D28222202

Pulled By: QuentinDuval

fbshipit-source-id: 12355db21e01e27f99c2152c26857a41de94d376
QuentinDuval authored and facebook-github-bot committed May 5, 2021
1 parent 297b505 commit 28de28e
Showing 4 changed files with 174 additions and 17 deletions.
1 change: 1 addition & 0 deletions dev/run_quick_tests.sh
```diff
@@ -16,6 +16,7 @@ TEST_LIST=(
     "test_regnet_fsdp.py"
     "test_regnet_fsdp_integration.py"
     "test_state_checkpointing.py"
+    "test_layer_memory_tracking.py"
 )
 
 echo "========================================================================"
```
81 changes: 64 additions & 17 deletions tests/test_layer_memory_tracking.py
```diff
@@ -3,17 +3,20 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import contextlib
 import unittest
 
 import torch
 import torch.nn as nn
 import torchvision.models as models
-from vissl.utils.layer_memory_tracking import LayerwiseMemoryTracker
+from vissl.utils.layer_memory_tracking import (
+    LayerwiseMemoryTracker,
+    find_best_reset_points,
+)
+from vissl.utils.test_utils import gpu_test, with_timing
 
 
 class TestLayerMemoryTracking(unittest.TestCase):
-    @unittest.skipIf(not torch.cuda.is_available(), "Test requires at least 1 GPU")
+    @gpu_test(gpu_count=1)
     def test_memory_tracking(self):
 
         # Create a model with a hierarchy of modules
@@ -70,24 +73,68 @@ def test_memory_tracking(self):
         for trace in top_act_producers:
             self.assertEqual(25233408, trace.event.memory_activations)
 
-    @contextlib.contextmanager
-    def with_timing(self, name: str):
-        start_event = torch.cuda.Event(enable_timing=True)
-        end_event = torch.cuda.Event(enable_timing=True)
-        start_event.record()
-        yield
-        end_event.record()
-        torch.cuda.synchronize()  # Wait for the events to be recorded!
-        elapsed_time_ms = start_event.elapsed_time(end_event)
-        print(name, ":", elapsed_time_ms, "ms")
-
-    @unittest.skipIf(not torch.cuda.is_available(), "Test requires at least 1 GPU")
+    @gpu_test(gpu_count=1)
     def test_memory_tracking_performance_impact(self):
         torch.manual_seed(0)
         model = models.resnet18()
-        with self.with_timing("no_tracking"):
+        with with_timing("no_tracking"):
             model(torch.randn(size=(1, 3, 224, 224)))
-        with self.with_timing("with_tracking"):
+        with with_timing("with_tracking"):
             tracker = LayerwiseMemoryTracker()
             tracker.monitor(model)
             model(torch.randn(size=(1, 3, 224, 224)))
+
+    def test_find_best_reset_points(self):
+        """
+        Verify that the reset points are correctly computed
+        """
+        activations = [10, 8, 8, 9, 7, 7, 5, 4, 4]
+
+        # Check boundary condition: no checkpoints
+        memory, split_points = find_best_reset_points(activations, nb_checkpoints=0)
+        self.assertEqual(memory, sum(activations))
+
+        # Check boundary condition: checkpoints everywhere
+        memory, split_points = find_best_reset_points(
+            activations, nb_checkpoints=len(activations)
+        )
+        self.assertEqual(memory, max(activations))
+
+        # Check one checkpoint allocation
+        memory, split_points = find_best_reset_points(activations, nb_checkpoints=1)
+        self.assertEqual(memory, 35)
+        self.assertEqual(split_points, [4])
+        self.assertEqual(sum(activations[: split_points[0]]), 35)
+        self.assertEqual(sum(activations[split_points[0] :]), 27)
+
+        # Check multiple checkpoint allocation
+        memory, split_points = find_best_reset_points(activations, nb_checkpoints=2)
+        self.assertEqual(memory, 24)
+        delimiters = [0] + split_points + [len(activations)]
+        splits_memory = [
+            sum(activations[i:j]) for i, j in zip(delimiters[:-1], delimiters[1:])
+        ]
+        self.assertEqual(max(splits_memory), memory)
+
+    @gpu_test(gpu_count=1)
+    def test_find_best_reset_points_performance(self):
+        """
+        Test that the algorithm has O(N**2) complexity in the number of activations N
+        """
+        import numpy as np
+
+        activations_1000 = list(np.random.randint(low=0, high=1_000_000, size=1_000))
+        activations_2000 = list(np.random.randint(low=0, high=1_000_000, size=2_000))
+        nb_checkpoints = 10
+        with with_timing(name="best_reset_points_1000") as timer_1000:
+            find_best_reset_points(activations_1000, nb_checkpoints=nb_checkpoints)
+        with with_timing(name="best_reset_points_2000") as timer_2000:
+            find_best_reset_points(activations_2000, nb_checkpoints=nb_checkpoints)
+        self.assertGreaterEqual(timer_2000.elapsed_time_ms, timer_1000.elapsed_time_ms)
+        self.assertLessEqual(timer_2000.elapsed_time_ms, timer_1000.elapsed_time_ms * 6)
+
+
+if __name__ == "__main__":
+    test = TestLayerMemoryTracking()
+    test.test_find_best_reset_points()
+    test.test_find_best_reset_points_performance()
```
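
A note on the 6x bound asserted in `test_find_best_reset_points_performance` above: thanks to the `lru_cache` memoization in `find_best_reset_points` (next file), `visit` has at most N * C distinct `(pos, remaining)` states for N activations and C checkpoints, and each state scans O(N) candidate split positions, giving:

```math
T(N, C) = O(N^2 \cdot C)
```

With `nb_checkpoints` fixed at 10, doubling N from 1,000 to 2,000 should therefore increase the runtime by roughly a factor of 4; the test bounds it by 6 to absorb timing noise.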
87 changes: 87 additions & 0 deletions vissl/utils/layer_memory_tracking.py
```diff
@@ -5,6 +5,7 @@
 
 from dataclasses import dataclass
 from enum import Enum, auto
+from functools import lru_cache
 from typing import Dict, List, NamedTuple, Tuple, Union
 
 import numpy as np
@@ -418,6 +419,92 @@ def _collect_tensors(module_io_tensors):
     return tensors
 
 
+def find_best_reset_points(
+    activation_sizes: List[int], nb_checkpoints: int
+) -> Tuple[int, List[int]]:
+    """
+    Assuming constant memory requirement from the model, its gradients
+    and the associated optimizer state (realistic for small models
+    or models that are sharded enough to be considered small), this
+    function computes the ideal placement for the checkpoints by
+    returning the limits at which we should reset memory.
+    """
+    n = len(activation_sizes)
+
+    @lru_cache(maxsize=None)
+    def visit(pos: int, remaining: int) -> Tuple[int, List[int]]:
+        # Past the end: no more activation memory to accumulate
+        if pos == n:
+            return 0, []
+        # No checkpoints left: all remaining activations accumulate together
+        if remaining == 0:
+            return sum(activation_sizes[pos:]), []
+
+        min_val = float("inf")
+        allocation = []
+
+        # Try every end position for the chunk starting at `pos`, then
+        # recurse on the remainder with one checkpoint less
+        current_chunk = 0
+        for curr_pos in range(pos, n):
+            current_chunk += activation_sizes[curr_pos]
+            sub_result, sub_alloc = visit(curr_pos + 1, remaining - 1)
+            result = max(current_chunk, sub_result)
+            if result < min_val:
+                min_val = result
+                allocation = list(sub_alloc)
+                allocation.append(curr_pos + 1)
+
+        return min_val, allocation
+
+    best_score, best_allocation = visit(0, nb_checkpoints)
+    return best_score, best_allocation[::-1]
+
+
+@dataclass
+class SuggestedCheckpoints:
+    max_memory: int
+    split_modules: List[str]
+    all_modules: List[str]
+
+
+def suggest_checkpoint_location(
+    traces: List[LayerMemoryTrace], nb_checkpoints: int, num_skipped_layers: int
+) -> SuggestedCheckpoints:
+    """
+    Given a trace of a model, collected with or without checkpoints,
+    return the best places to insert a reset of activation memory.
+    The names of the returned modules are the boundaries of the
+    suggested checkpoint_wrapper wrappings.
+    """
+
+    # From the traces, extract how much activation memory
+    # is generated during the forward pass, layer by layer
+    visited = set()
+    modules, allocations = [], []
+    for t in traces:
+        if t.is_forward:
+            name = t.module_name
+            memory = t.event.memory_activations
+            if name not in visited:
+                visited.add(name)
+                modules.append(name)
+                allocations.append(memory)
+
+    # Remove the stem part
+    modules = modules[num_skipped_layers:]
+    allocations = allocations[num_skipped_layers:]
+
+    # Compute the best positions to reset the memory
+    max_memory, reset_indices = find_best_reset_points(
+        allocations, nb_checkpoints=nb_checkpoints
+    )
+
+    # Then map it back to module names
+    return SuggestedCheckpoints(
+        max_memory=max_memory,
+        split_modules=[modules[i] for i in reset_indices],
+        all_modules=modules,
+    )
+
+
 def compare_memory_traces_in_plot(
     memory_traces_by_job: Dict[str, List[LayerMemoryTrace]],
     figsize: Tuple[int, int] = (16, 20),
```
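
To make the dynamic program concrete, here is a worked example using the same activation sizes as `test_find_best_reset_points` above (the expected values are taken directly from that test):

```python
from vissl.utils.layer_memory_tracking import find_best_reset_points

activations = [10, 8, 8, 9, 7, 7, 5, 4, 4]  # per-layer activation sizes

# One checkpoint: the best single split is before index 4, giving chunks
# 10+8+8+9 = 35 and 7+7+5+4+4 = 27, so peak activation memory is 35.
memory, splits = find_best_reset_points(activations, nb_checkpoints=1)
print(memory, splits)  # 35 [4]

# Two checkpoints: the peak over the three resulting chunks drops to 24.
memory, splits = find_best_reset_points(activations, nb_checkpoints=2)
print(memory)  # 24
```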
22 changes: 22 additions & 0 deletions vissl/utils/test_utils.py
```diff
@@ -8,6 +8,7 @@
 import shutil
 import tempfile
 from contextlib import contextmanager
+from dataclasses import dataclass
 from typing import List, Tuple
 
 import torch
@@ -51,6 +52,27 @@ def with_temp_files(count: int):
         os.close(t[0])
 
 
+@dataclass
+class TestTimer:
+    elapsed_time_ms: int
+
+
+@contextmanager
+def with_timing(name: str):
+    """
+    Test utilities for basic performance tests
+    """
+    test_timer = TestTimer(elapsed_time_ms=0)
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    start_event.record()
+    yield test_timer
+    end_event.record()
+    torch.cuda.synchronize()  # Wait for the events to be recorded!
+    test_timer.elapsed_time_ms = start_event.elapsed_time(end_event)
+    print(name, ":", test_timer.elapsed_time_ms, "ms")
+
+
 def gpu_test(gpu_count: int = 1):
     """
     Annotation for GPU tests, skipping the test if the
```
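
For completeness, a minimal usage sketch of the new `with_timing` utility (it relies on CUDA events, so a CUDA-capable machine is assumed):

```python
import torch
from vissl.utils.test_utils import with_timing

with with_timing("matmul") as timer:
    a = torch.randn(1024, 1024, device="cuda")
    _ = a @ a
# The elapsed time is available on the yielded timer after the block
print(timer.elapsed_time_ms)
```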
