Skip to content

Commit

Permalink
fix commitment loss
Browse files · Browse the repository at this point in the history
  • Loading branch information
adefossez committed Dec 11, 2023
1 parent 89baee7 commit 62c26ae
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ AudioCraft requires Python 3.9, PyTorch 2.0.0. To install AudioCraft, you can ru
```shell
# Best to make sure you have torch installed first, in particular before installing xformers.
# Don't run this if you already have PyTorch installed.
# NOTE: torch is pinned to 2.1.0 (this commit replaces the looser 'torch>=2.0' requirement).
python -m pip install 'torch==2.1.0'
# You might need the following before trying to install the packages
python -m pip install setuptools wheel
# Then proceed to one of the following
python -m pip install -U audiocraft # stable release
python -m pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft # bleeding edge
Expand Down
5 changes: 5 additions & 0 deletions audiocraft/quantization/core_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,16 @@ def forward(self, x, n_q: tp.Optional[int] = None):

for i, layer in enumerate(self.layers[:n_q]):
quantized, indices, loss = layer(residual)
quantized = quantized.detach()
residual = residual - quantized
quantized_out = quantized_out + quantized
all_indices.append(indices)
all_losses.append(loss)

if self.training:
# Solving subtle bug with STE and RVQ: https://github.com/facebookresearch/encodec/issues/25
quantized_out = x + (quantized_out - x).detach()

out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses

Expand Down
4 changes: 3 additions & 1 deletion tests/quantization/test_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
class TestResidualVectorQuantizer:
    """Unit tests for the residual vector quantizer (RVQ)."""

    def test_rvq(self):
        """Check the quantizer's output shape and that gradients flow straight
        through the quantization to the input.

        With the straight-through estimator applied in the quantizer
        (``quantized_out = x + (quantized_out - x).detach()``), the Jacobian of
        the output w.r.t. the input is the identity, so after ``sum().backward()``
        the input gradient must be all ones.
        """
        # requires_grad=True is needed so x.grad is populated by backward() below.
        # (The earlier, shadowed `x = torch.randn(1, 16, 2048)` assignment was a
        # leftover from the pre-fix version and has been removed.)
        x = torch.randn(1, 16, 2048, requires_grad=True)
        vq = ResidualVectorQuantizer(n_q=8, dimension=16, bins=8)
        res = vq(x, 1.)  # NOTE(review): second positional arg presumably the frame rate — confirm against ResidualVectorQuantizer.forward
        assert res.x.shape == torch.Size([1, 16, 2048])
        res.x.sum().backward()
        # Straight-through estimator => identity Jacobian => gradient of sum() is all ones.
        assert torch.allclose(x.grad.data, torch.ones(1))

0 comments on commit 62c26ae

Please sign in to comment.