Skip to content

Commit

Permalink
Merge pull request #41 from sp-nitech/learnable_pqmf
Browse files Browse the repository at this point in the history
Learnable PQMF
  • Loading branch information
takenori-y committed Jun 20, 2023
2 parents 9156b3e + b46fa82 commit a5812e5
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 4 deletions.
11 changes: 9 additions & 2 deletions diffsptk/core/ipqmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,15 @@ class InversePseudoQuadratureMirrorFilterBanks(nn.Module):
alpha : float > 0 [scalar]
Stopband attenuation in dB.
learnable : bool [scalar]
Whether to make filter-bank coefficients learnable.
**kwargs : additional keyword arguments
Parameters to find optimal filter-bank coefficients.
"""

def __init__(self, n_band, filter_order, alpha=100, **kwargs):
def __init__(self, n_band, filter_order, alpha=100, learnable=False, **kwargs):
super(InversePseudoQuadratureMirrorFilterBanks, self).__init__()

assert 1 <= n_band
Expand All @@ -59,7 +62,11 @@ def __init__(self, n_band, filter_order, alpha=100, **kwargs):
warnings.warn("Failed to find PQMF coefficients")
filters = np.expand_dims(filters, 0)
filters = np.flip(filters, 2).copy()
self.register_buffer("filters", numpy_to_torch(filters))
filters = numpy_to_torch(filters)
if learnable:
self.filters = nn.Parameter(filters)
else:
self.register_buffer("filters", filters)

# Make padding module.
if filter_order % 2 == 0:
Expand Down
11 changes: 9 additions & 2 deletions diffsptk/core/pqmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,15 @@ class PseudoQuadratureMirrorFilterBanks(nn.Module):
alpha : float > 0 [scalar]
Stopband attenuation in dB.
learnable : bool [scalar]
Whether to make filter-bank coefficients learnable.
**kwargs : additional keyword arguments
Parameters to find optimal filter-bank coefficients.
"""

def __init__(self, n_band, filter_order, alpha=100, **kwargs):
def __init__(self, n_band, filter_order, alpha=100, learnable=False, **kwargs):
super(PseudoQuadratureMirrorFilterBanks, self).__init__()

assert 1 <= n_band
Expand All @@ -174,7 +177,11 @@ def __init__(self, n_band, filter_order, alpha=100, **kwargs):
warnings.warn("Failed to find PQMF coefficients")
filters = np.expand_dims(filters, 1)
filters = np.flip(filters, 2).copy()
self.register_buffer("filters", numpy_to_torch(filters))
filters = numpy_to_torch(filters)
if learnable:
self.filters = nn.Parameter(filters)
else:
self.register_buffer("filters", filters)

# Make padding module.
if filter_order % 2 == 0:
Expand Down
5 changes: 5 additions & 0 deletions tests/test_ipqmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,8 @@ def test_compatibility(device, a, M, tau=0.01, eps=0.01, K=4, T=20):
)

U.check_differentiable(device, ipqmf, [K, T], opt={"keepdim": False})


def test_learnable(K=4, M=10, T=20):
    """Check that IPQMF filter-bank coefficients are trainable."""
    module = diffsptk.IPQMF(K, M, learnable=True)
    U.check_learnable(module, (K, T))
5 changes: 5 additions & 0 deletions tests/test_pqmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,8 @@ def test_compatibility(device, a, M, tau=0.01, eps=0.01, K=4, T=20):
def test_various_shape(K=4, M=10, T=20):
    """Check that PQMF accepts inputs of various shapes."""
    module = diffsptk.PQMF(K, M)
    U.check_various_shape(module, [(T,), (1, T), (1, 1, T)])


def test_learnable(K=4, M=10, T=20):
    """Check that PQMF filter-bank coefficients are trainable."""
    module = diffsptk.PQMF(K, M, learnable=True)
    U.check_learnable(module, (T,))
21 changes: 21 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,24 @@ def check_various_shape(module, shapes):
target = y
else:
assert torch.allclose(y, target)


def check_learnable(module, shape):
    """Check that a module's parameters are updated by gradient descent.

    Feeds a random tensor through the module, takes one SGD step on the
    mean of the output, and asserts that every parameter actually moved.

    Parameters
    ----------
    module : torch.nn.Module
        Module to check. Must expose at least one learnable parameter.

    shape : tuple[int, ...]
        Shape of the random input tensor fed to the module.

    """
    # Snapshot parameters before the update. Fail fast if there is nothing
    # to learn; otherwise the comparison loop below would pass vacuously.
    params_before = [p.clone() for p in module.parameters()]
    assert len(params_before) > 0, "module has no learnable parameters"

    optimizer = torch.optim.SGD(module.parameters(), lr=0.01)
    x = torch.randn(*shape)
    y = module(x)
    optimizer.zero_grad()
    loss = y.mean()
    loss.backward()
    optimizer.step()

    params_after = [p.clone() for p in module.parameters()]

    # Every parameter must have changed after one optimization step.
    for pb, pa in zip(params_before, params_after):
        assert not torch.allclose(pb, pa)

0 comments on commit a5812e5

Please sign in to comment.