Skip to content

Commit

Permalink
improve coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
takenori-y committed Feb 13, 2023
1 parent 45a8659 commit fa45462
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
5 changes: 2 additions & 3 deletions diffsptk/core/lbg.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,8 @@ def forward(self, x):
self.logger.info(f"K={curr_codebook_size} {n:5d}: {distance:g}")

# Check convergence.
if distance == 0:
break
if 0 < n and (prev_distance - distance).abs() / distance < self.eps:
diff = (prev_distance - distance).abs()
if n and diff / (distance + 1e-16) < self.eps:
break
prev_distance = distance

Expand Down
12 changes: 10 additions & 2 deletions tests/test_lbg.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@

@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_compatibility(device, M=1, K=4, B=10, n_iter=10):
lbg = diffsptk.LindeBuzoGrayAlgorithm(M, K, n_iter=n_iter)

torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
lbg = diffsptk.LindeBuzoGrayAlgorithm(M, K, n_iter=n_iter)

tmp1 = "lbg.tmp1"
tmp2 = "lbg.tmp2"
Expand All @@ -52,3 +51,12 @@ def test_compatibility(device, M=1, K=4, B=10, n_iter=10):
dx=M + 1,
rtol=0.1,
)


def test_min_data_per_cluster(M=1, K=4, B=10):
torch.manual_seed(1234)
x = torch.randn(B, M + 1)
lbg = diffsptk.LindeBuzoGrayAlgorithm(
M, K, n_iter=10, min_data_per_cluster=int(B * 0.9)
)
_, _ = lbg(x)

0 comments on commit fa45462

Please sign in to comment.