Grad clip for parameters on different devices (pytorch#9302)
Summary:
I'm trying to write a multi-GPU network by pipelining some layers onto different GPUs. However, the current gradient clipping utility requires all of the parameters to be located on the same device.

CUDA launch overhead is reduced because the scalar accumulation is performed on the CPU, though this introduces extra device-to-host data transfers.
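For illustration, here is a minimal sketch of the intended use case, assuming two CUDA devices (the stage sizes and variable names are made up for the example and are not part of the patch):

```python
import torch
import torch.nn as nn

# Hypothetical two-stage pipeline: each stage lives on its own GPU.
stage1 = nn.Linear(1024, 1024).to('cuda:0')
stage2 = nn.Linear(1024, 10).to('cuda:1')

x = torch.randn(8, 1024, device='cuda:0')
loss = stage2(stage1(x).to('cuda:1')).sum()
loss.backward()

# With the norm accumulated on the CPU, the parameters being clipped
# no longer need to share a device.
params = list(stage1.parameters()) + list(stage2.parameters())
total_norm = torch.nn.utils.clip_grad_norm_(params, 1)
print(total_norm)
```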

No performance regression was observed when running the following snippet:
```python
import time

import torch

# CUDA modules used only as a source of parameters for the clipping call;
# they are never run forward.
module = torch.nn.Sequential(
    torch.nn.LSTM(1024, 1024),
    torch.nn.LSTM(256, 256),
    torch.nn.Linear(100, 10000),
).cuda()

# Warm-up call before timing.
torch.nn.utils.clip_grad_norm_(module.parameters(), 1)
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
    torch.nn.utils.clip_grad_norm_(module.parameters(), 1)
torch.cuda.synchronize()
time_elapse = time.time() - start
# Total seconds over 1000 calls equals milliseconds per call.
print('{} ms per clip'.format(time_elapse))
```
Pull Request resolved: pytorch#9302

Differential Revision: D8781551

Pulled By: soumith

fbshipit-source-id: 9d76d01fe0531927f770a16b9523872a7e08e927
Stonesjtu authored and facebook-github-bot committed Jul 10, 2018
1 parent 1597fc5 commit 89c2b50
Showing 1 changed file with 2 additions and 2 deletions.
torch/nn/utils/clip_grad.py (2 additions, 2 deletions)
```diff
@@ -29,12 +29,12 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2):
         total_norm = 0
         for p in parameters:
             param_norm = p.grad.data.norm(norm_type)
-            total_norm += param_norm ** norm_type
+            total_norm += param_norm.item() ** norm_type
         total_norm = total_norm ** (1. / norm_type)
     clip_coef = max_norm / (total_norm + 1e-6)
     if clip_coef < 1:
         for p in parameters:
-            p.grad.data.mul_(clip_coef.item())
+            p.grad.data.mul_(clip_coef)
     return total_norm
```


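To see why the `.item()` call matters, here is a small sketch of the accumulation before and after the change (the gradient tensors and `norm_type` value are illustrative, and two CUDA devices are assumed):

```python
import torch

# Gradients of two parameters living on different GPUs.
g0 = torch.randn(10, device='cuda:0')
g1 = torch.randn(10, device='cuda:1')
norm_type = 2

# Before the patch, per-parameter norms were accumulated as tensors, so the
# running sum lived on the first gradient's device and, at the time of this
# PR, adding a norm computed on another device failed:
# total_norm = g0.norm(norm_type) ** norm_type + g1.norm(norm_type) ** norm_type

# After the patch, each norm is pulled to the CPU as a Python float via
# .item(), so the accumulation no longer depends on where the gradients live.
total_norm = g0.norm(norm_type).item() ** norm_type \
    + g1.norm(norm_type).item() ** norm_type
total_norm = total_norm ** (1. / norm_type)

# The clip coefficient is then an ordinary float and can scale gradients
# on any device in place.
max_norm = 1.0
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
    for g in (g0, g1):
        g.mul_(clip_coef)
```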
