This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

[feature] (activation stats): ability to track the statistics of each layer in the model when training #264

Closed

Conversation

@QuentinDuval (Contributor) commented Mar 30, 2021

Tracking activation statistics during training

This PR adds a new monitoring utility, plugged into VISSL via the Tensorboard hook, which captures the output of the "leaf modules" (not all modules), computes mean and spread statistics on the output of each layer, and tracks them in Tensorboard.
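A minimal sketch of the idea (illustrative only, not the actual VISSL implementation; `ActivationMonitor` and its methods are hypothetical names): register a forward hook on every leaf module, record mean/spread of each output, and remove the hooks when done.

```python
# Hedged sketch of activation monitoring via forward hooks (assumed API shape,
# not the real VISSL utility).
import torch
import torch.nn as nn

class ActivationMonitor:
    def __init__(self):
        self._hooks = []
        self.stats = {}  # module name -> (mean, spread)

    def start(self, model: nn.Module):
        # Hook only leaf modules, i.e. modules with no child modules.
        for name, m in model.named_modules():
            if len(list(m.children())) == 0:
                self._hooks.append(m.register_forward_hook(self._make_hook(name)))

    def _make_hook(self, name):
        def hook(module, inputs, output):
            out = output.detach()
            # Mean plus a single "maximum spread" value around the mean.
            self.stats[name] = (
                out.mean().item(),
                (out - out.mean()).abs().max().item(),
            )
        return hook

    def stop(self):
        for h in self._hooks:
            h.remove()
        self._hooks = []  # idempotent: calling stop() twice is safe

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
monitor = ActivationMonitor()
monitor.start(model)
model(torch.randn(2, 4))
monitor.stop()
# monitor.stats now has one entry per leaf module (the Linear and the ReLU)
```

In the real hook, these per-layer scalars would then be written to Tensorboard at the configured frequency.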

  • @QuentinDuval to re-run performance tests after review comments and refactoring on SwAV
  • @QuentinDuval to try to add some quick performance tests in the unit tests

Output example:

[Screenshot: Tensorboard activation-statistics plots, 2021-04-07]

Description

The configuration has been updated with the following option (by default disabled):

```yaml
MONITORING:
  # At which frequency do we monitor statistics on the activations:
  # - 0 means that we do not monitor statistics
  # - N > 0 means we monitor every N iterations
  MONITOR_ACTIVATION_STATISTICS: 0
```

Turn the option on with the following Hydra override: `config.MONITORING.MONITOR_ACTIVATION_STATISTICS=50` to gather statistics on all activations every 50 iterations.

NOTE: This option requires the tensorboard hook to be enabled to take effect.

Performance impact: collecting the statistics uses the following optimisation: for feature maps, we compute the statistics only on the central feature, which requires less compute and memory than the full feature map while still exercising all the weights of the BN or Conv2d layer. With this optimisation, the memory impact is negligible and the runtime impact is about 1.4% on SimCLR when statistics are collected every 50 iterations.
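The central-feature optimisation can be sketched as follows (a hedged illustration of the idea described above, not the exact VISSL code; `central_feature` and `spread_stats` are hypothetical names): for an `(N, C, H, W)` feature map, keep only the central spatial position, which still involves every channel.

```python
# Hedged sketch: statistics on the central feature of a conv feature map.
import torch

def central_feature(out: torch.Tensor) -> torch.Tensor:
    """Reduce an (N, C, H, W) feature map to its central spatial feature."""
    if out.dim() == 4:
        h, w = out.shape[2] // 2, out.shape[3] // 2
        return out[:, :, h, w]  # shape (N, C): one value per sample/channel
    return out  # non-spatial outputs (e.g. linear layers) are kept as-is

def spread_stats(out: torch.Tensor):
    sample = central_feature(out)
    mean = sample.mean().item()
    # One "maximum spread" scalar instead of tracking min and max separately.
    spread = (sample - sample.mean()).abs().max().item()
    return mean, spread

x = torch.randn(8, 16, 32, 32)   # N=8, C=16, H=W=32
mean, spread = spread_stats(x)   # computed on an (8, 16) slice, not 8*16*32*32
```

Since every output channel of a Conv2d or BN layer contributes to the central position, the sampled slice still reflects all the layer's weights.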

Further improvements

  • Support for backends other than Tensorboard: ideally, we should be able to dump raw data / images
  • Profile what takes time when computing statistics and optimise it further if possible
  • Add an "alerting" feature that varies the tracking frequency based on detected divergence

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 30, 2021
@prigoyal (Contributor)

Hi @QuentinDuval , all the points you raised are great. I'd propose that we meet over VC to discuss these.

certainly interested in the fixes, feel free to make a PR :)


@prigoyal left a comment


I did a high-level design overview of this and it looks great to me! :) Early feedback, but it is heading in the right direction :) Great work @QuentinDuval :)

(review comment on vissl/config/defaults.yaml, resolved)

@prigoyal left a comment


Looks great to me @QuentinDuval, no comments code-wise. Next steps as we discussed, and then it's good to go :)

(review comment on tests/test_activation_statistics.py, resolved)

```python
        h2 = m.register_forward_hook(self._create_post_forward_hook(name))
        self._hooks.extend([h1, h2])

    def stop(self):
```

nice :)

@facebook-github-bot

@QuentinDuval has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


@prigoyal left a comment


this is super awesome work @QuentinDuval :)

  1. quick clarification: leaf modules -> doesn't mean just head is tracked right? from the code, it seems like trunk+head both are tracked but figures in test plan are for the heads so just wanted to double check :)

  2. we should insert the copyright headers everywhere -> blocker for landing this PR

@QuentinDuval (Contributor, Author)

> this is super awesome work @QuentinDuval :)
>
>   1. quick clarification: leaf modules -> doesn't mean just head is tracked right? from the code, it seems like trunk+head both are tracked but figures in test plan are for the heads so just wanted to double check :)
>   2. we should insert the copyright headers everywhere -> blocker for landing this PR

Hi @prigoyal :)

By leaf module, I meant a module that does not encapsulate other modules. For instance, in nn.Sequential(nn.Linear(...), nn.ReLU(...)), the leaf modules are nn.Linear and nn.ReLU, and we ignore the nn.Sequential. This also means we ignore things like Bottleneck (which only contains other modules), but it might also skip some interesting modules now that I think about it (*).

Otherwise, all modules are monitored as long as they are in training mode: we ignore modules that are not being trained, such as frozen modules during linear evaluation, since there is not much to monitor in that case.

(*) I think we can improve that in a later PR: ignore modules that have no parameters of their own but only parameters in their children, or simply hardcode what we want to ignore.
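The leaf-module selection described above can be sketched in a few lines (a hedged illustration, not the VISSL code; `leaf_modules` is a hypothetical helper): a leaf is simply a module with no child modules.

```python
# Hedged sketch of "leaf module" selection: keep modules with no children.
import torch.nn as nn

def leaf_modules(model: nn.Module):
    """Yield (name, module) for modules that do not contain other modules."""
    for name, m in model.named_modules():
        if len(list(m.children())) == 0:
            yield name, m

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
names = [type(m).__name__ for _, m in leaf_modules(model)]
# names contains 'Linear' and 'ReLU'; the enclosing Sequential is skipped
```

As noted in (*), this criterion would also skip containers like Bottleneck, and could later be refined to, e.g., include any module that owns parameters directly.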

@facebook-github-bot

@QuentinDuval has updated the pull request. You must reimport the pull request before landing.

@QuentinDuval (Contributor, Author)

> this is super awesome work @QuentinDuval :)
>
>   1. quick clarification: leaf modules -> doesn't mean just head is tracked right? from the code, it seems like trunk+head both are tracked but figures in test plan are for the heads so just wanted to double check :)
>   2. we should insert the copyright headers everywhere -> blocker for landing this PR

I added the copyrights in d85434b 👍 Good catch!

@facebook-github-bot

@QuentinDuval has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

  • … layer in the model when training - greatly improving performance by sampling feature maps at the center (so that each parameter is used) and computing the maximum spread instead of the min and max
  • … layer in the model when training - decreasing GPU memory usage when using the "sample feature map" flag
  • … layer in the model when training - reset the _prev_module_name in post_forward_hook (avoid potential future bugs)
  • … layer in the model when training - renaming and documentation
  • … layer in the model when training - renaming and documentation
  • … layer in the model when training - renaming and documentation
  • … layer in the model when training - bug fixing: make stop idempotent
  • … layer in the model when training - add missing copyright header
@facebook-github-bot

@QuentinDuval has updated the pull request. You must reimport the pull request before landing.

@facebook-github-bot

@QuentinDuval has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot

@QuentinDuval merged this pull request in 4c245c8.

facebook-github-bot pushed a commit that referenced this pull request May 3, 2022
Summary: Pull Request resolved: fairinternal/ssl_scaling#264

Reviewed By: mannatsingh

Differential Revision: D35579626

Pulled By: QuentinDuval

fbshipit-source-id: 42d25b576ed8451ddd6bc500fdb8dc39c072bb0e