# Memory profiling (facebookresearch#100)
Summary:

# Layer by layer memory profiling

A first version of the memory profiling, tracking the memory used through the forward/backward passes, with a breakdown of the memory dedicated to activations (issue fairinternal/ssl_scaling#97).

- [x] Define the test plan
- [x] Provide example curves and data output
- [x] Run on FSDP vs DDP
- [x] Run on FSDP with or without checkpointing

## Using the feature

Just add `cfg.PROFILING.TRACK_BY_LAYER_MEMORY=True` on the command line when running a job to track the memory usage, layer by layer, during both the forward and backward passes. Further configuration is available to choose:

- which rank is monitored
- for how many iterations
- starting from which iteration

Pull Request resolved: fairinternal/ssl_scaling#100

Test Plan: The feature comes with its own set of unit tests.

## Example outputs

The output directory will contain the following files for each rank and iteration monitored:

```
memory_rank_0_iteration_0.json
memory_rank_0_iteration_0.jpg
```

The JSON file contains the raw data, while the JPG file provides an overview of what is happening in terms of memory:

<img width="1047" alt="Screenshot 2021-04-19 at 11 26 06" src="https://user-images.githubusercontent.com/7412790/115261974-19376780-a102-11eb-838c-688d807094d3.png">

Reviewed By: prigoyal

Differential Revision: D27977734

Pulled By: QuentinDuval

fbshipit-source-id: 4000f84e418afecb7c02dee5c5add260a04046ba
1 parent 4adeacf · commit 4708348 · 7 changed files with 598 additions and 0 deletions
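For reference before the diffs: the `LayerwiseMemoryTracker` API that the unit tests and hook below exercise can also be driven directly from Python. A minimal sketch, assuming a CUDA-capable machine; the model, shapes, and output file name here are illustrative, not part of the commit:

```python
import torch
import torch.nn as nn

from vissl.utils.layer_memory_tracking import LayerwiseMemoryTracker

# Any nn.Module works; the tracker instruments its leaf modules.
model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3),
    nn.ReLU(inplace=False),
    nn.AdaptiveAvgPool2d(output_size=(1, 1)),
    nn.Flatten(start_dim=1),
    nn.Linear(64, 2),
).cuda()

tracker = LayerwiseMemoryTracker()
tracker.monitor(model)  # install the memory tracking hooks

# Run an ordinary forward/backward: traces are recorded as a side effect.
x = torch.randn(size=(2, 3, 224, 224)).cuda()
model(x).sum().backward()

# Inspect the aggregated statistics, dump the plot, detach the tracker.
print(tracker.summary.total_activation_allocations)
image = tracker.show_plots(capture=True)
image.save("memory_rank_0_iteration_0.jpg")
tracker.stop()
```

The `monitor` / `stop` pair brackets the instrumented window; the `ProfilingHook` in the second file below automates exactly this bracketing across iterations via `START_ITERATION` and `NUM_ITERATIONS`.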
@@ -0,0 +1,91 @@

```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import contextlib
import unittest

import torch
import torch.nn as nn
import torchvision.models as models
from vissl.utils.layer_memory_tracking import LayerwiseMemoryTracker


class TestLayerMemoryTracking(unittest.TestCase):

    @unittest.skipIf(not torch.cuda.is_available(), "Test requires at least 1 GPU")
    def test_memory_tracking(self):

        # Create a model with a hierarchy of modules
        torch.manual_seed(0)
        model = nn.Sequential(
            nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=False),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            ),
            nn.Flatten(start_dim=1),
            nn.Sequential(nn.Linear(64, 2), nn.ReLU(inplace=True)),
        ).cuda()

        # Track a fake forward / backward
        tracker = LayerwiseMemoryTracker()
        tracker.monitor(model)
        x = torch.randn(size=(2, 3, 224, 224)).cuda()
        target = torch.LongTensor([0, 1]).cuda()
        criterion = nn.CrossEntropyLoss()
        criterion(model(x), target).backward()

        # Verify that only leaf modules are tracked
        tracked_names = {trace.module_name for trace in tracker.memory_traces}
        expected_names = {"0.0", "0.1", "0.2", "0.3", "1", "2.0", "2.1"}
        self.assertEqual(expected_names, tracked_names)

        # Verify that memory tracking for ReLU is sound
        self.assertEqual(
            25233408,
            tracker.forward_traces[2].event.memory_activations,
            "ReLU(inplace=False) should allocate activations",
        )
        self.assertEqual(
            0,
            tracker.forward_traces[6].event.memory_activations,
            "ReLU(inplace=True) should NOT allocate activations",
        )

        # Verify that overall memory tracking is sound
        summary = tracker.summary
        self.assertGreaterEqual(
            summary.total_forward_allocations, summary.total_activation_allocations
        )

        top_act_producers = summary.top_forward_activation_producers[:3]
        self.assertEqual("0.0", top_act_producers[0].module_name)
        self.assertEqual("0.1", top_act_producers[1].module_name)
        self.assertEqual("0.2", top_act_producers[2].module_name)
        self.assertEqual(7168, top_act_producers[0].module_params)
        self.assertEqual(512, top_act_producers[1].module_params)
        self.assertEqual(0, top_act_producers[2].module_params)
        for trace in top_act_producers:
            self.assertEqual(25233408, trace.event.memory_activations)

    @contextlib.contextmanager
    def with_timing(self, name: str):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        yield
        end_event.record()
        torch.cuda.synchronize()  # Wait for the events to be recorded!
        elapsed_time_ms = start_event.elapsed_time(end_event)
        print(name, ":", elapsed_time_ms, "ms")

    @unittest.skipIf(not torch.cuda.is_available(), "Test requires at least 1 GPU")
    def test_memory_tracking_performance_impact(self):
        torch.manual_seed(0)
        model = models.resnet18()
        with self.with_timing("no_tracking"):
            model(torch.randn(size=(1, 3, 224, 224)))
        with self.with_timing("with_tracking"):
            tracker = LayerwiseMemoryTracker()
            tracker.monitor(model)
            model(torch.randn(size=(1, 3, 224, 224)))
```
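A note on the hard-coded constants in `test_memory_tracking`: they are consistent with fp32 tensors and byte-level accounting. The unit (bytes) is inferred from the arithmetic below, not stated in the diff:

```python
# Conv2d(3, 64, kernel_size=3) maps the (2, 3, 224, 224) input to a
# (2, 64, 222, 222) fp32 output, i.e. the asserted activation size:
assert 2 * 64 * 222 * 222 * 4 == 25233408   # memory_activations, in bytes

# module_params also matches parameter bytes rather than element counts:
assert (64 * 3 * 3 * 3 + 64) * 4 == 7168    # Conv2d weight + bias
assert (64 + 64) * 4 == 512                 # BatchNorm2d weight + bias
```

The same 25233408 figure applies to all three top activation producers because the Conv2d, BatchNorm2d, and out-of-place ReLU each output a tensor of the same shape.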
@@ -0,0 +1,105 @@

```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import json
import logging
import os

from classy_vision import tasks
from classy_vision.hooks.classy_hook import ClassyHook
from vissl.config.attr_dict import AttrDict
from vissl.utils.env import get_machine_local_and_dist_rank
from vissl.utils.layer_memory_tracking import LayerwiseMemoryTracker


class ProfilingHook(ClassyHook):
    """
    Hook used to trigger the profiling features of VISSL
    """

    on_loss_and_meter = ClassyHook._noop
    on_forward = ClassyHook._noop
    on_backward = ClassyHook._noop
    on_step = ClassyHook._noop
    on_phase_start = ClassyHook._noop
    on_phase_end = ClassyHook._noop

    @staticmethod
    def is_enabled(profiling_config: AttrDict):
        """
        Returns whether or not the profiler hook should be instantiated:
        it should be enabled if any of the profiling options is on
        """
        return profiling_config.MEMORY_PROFILING.TRACK_BY_LAYER_MEMORY

    def __init__(self, profiling_config: AttrDict):
        super().__init__()
        self.output_folder = profiling_config.OUTPUT_FOLDER
        self.start_iteration = profiling_config.START_ITERATION
        self.end_iteration = (
            profiling_config.START_ITERATION + profiling_config.NUM_ITERATIONS
        )
        self.dist_rank = get_machine_local_and_dist_rank()[1]
        self.enabled = self.dist_rank in profiling_config.PROFILED_RANKS
        self.profile_memory = (
            self.enabled and profiling_config.MEMORY_PROFILING.TRACK_BY_LAYER_MEMORY
        )
        if self.profile_memory:
            logging.info(f"Setting up memory tracker for rank {self.dist_rank}...")
            self.layer_memory_tracker = LayerwiseMemoryTracker()

    def on_start(self, task: "tasks.ClassyTask") -> None:
        """
        Called at the start of training.
        """
        if self.profile_memory:
            assert (
                task.use_gpu is True
            ), "Profiling memory usage requires training on GPU"
        if self.profile_memory and self.start_iteration == 0:
            self.layer_memory_tracker.monitor(task.base_model)

    def on_end(self, task: "tasks.ClassyTask") -> None:
        """
        Called at the end of training.
        """
        if self.profile_memory:
            self.layer_memory_tracker.stop()

    def on_update(self, task: "tasks.ClassyTask") -> None:
        """
        Called after parameter update.
        """
        if self.profile_memory:
            iteration = task.local_iteration_num
            self._memory_tracking(iteration, task)

    def _memory_tracking(self, iteration: int, task: "tasks.ClassyTask"):
        """
        Handle the memory tracking logic:
        - enabling / disabling the tracker depending on the iteration
        - dumping the statistics collected in previous iteration
        - preparing the tracker for the next iteration
        """
        next_iteration = iteration + 1

        # Dump memory statistics
        if self.start_iteration <= iteration < self.end_iteration:
            # TODO (prigoyal): figure out how to save when using non-disk backend
            image = self.layer_memory_tracker.show_plots(capture=True)
            image_name = f"memory_rank_{self.dist_rank}_iteration_{iteration}.jpg"
            image.save(os.path.join(self.output_folder, image_name))
            json_name = f"memory_rank_{self.dist_rank}_iteration_{iteration}.json"
            # Save the raw JSON traces in the same output folder as the plot
            with open(os.path.join(self.output_folder, json_name), "w") as f:
                json_traces = {
                    "traces": [
                        t.to_dict() for t in self.layer_memory_tracker.memory_traces
                    ]
                }
                json.dump(json_traces, f)
            self.layer_memory_tracker.clear_traces()

        # Enable / disable the profiling based on the current iteration
        if next_iteration == self.start_iteration:
            self.layer_memory_tracker.monitor(task.base_model)
        if next_iteration == self.end_iteration:
            self.layer_memory_tracker.stop()
```
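Putting the pieces together, the attribute accesses in `__init__` imply a profiling configuration shaped roughly as follows. The key names come from the hook above; the values, and the assumption that `AttrDict` accepts a plain nested dict, are illustrative:

```python
from vissl.config.attr_dict import AttrDict

# Illustrative values only; key names are taken from ProfilingHook.__init__,
# and ProfilingHook is assumed to be in scope (it is defined above).
profiling_config = AttrDict(
    {
        "OUTPUT_FOLDER": ".",
        "START_ITERATION": 0,   # first iteration to profile
        "NUM_ITERATIONS": 2,    # profiles iterations [START, START + NUM)
        "PROFILED_RANKS": [0],  # distributed ranks on which the hook is active
        "MEMORY_PROFILING": {"TRACK_BY_LAYER_MEMORY": True},
    }
)

assert ProfilingHook.is_enabled(profiling_config)
hook = ProfilingHook(profiling_config)
```

With `START_ITERATION == 0` the tracker is attached in `on_start`; otherwise `_memory_tracking` attaches it at the end of the previous iteration, so the first profiled iteration is captured in full.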