Skip to content

Commit

Permalink
Update emd_module.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Colin97 committed Apr 8, 2022
1 parent c5f23a9 commit dc6450f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions emd/emd_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ def forward(self, input1, input2, eps, iters):
return emdFunction.apply(input1, input2, eps, iters)

def test_emd():
x1 = torch.rand(20, 8192, 3).cuda()
x1 = torch.rand(20, 8192, 3).cuda() # please normalize your point cloud to [0, 1]
x2 = torch.rand(20, 8192, 3).cuda()
emd = emdModule()
start_time = time.perf_counter()
dis, assigment = emd(x1, x2, 0.05, 3000)
dis, assigment = emd(x1, x2, 0.002, 10000) # 0.005, 50 for training
print("Input_size: ", x1.shape)
print("Runtime: %lfs" % (time.perf_counter() - start_time))
print("EMD: %lf" % np.sqrt(dis.cpu()).mean())
Expand All @@ -95,4 +95,4 @@ def test_emd():
print("Verified EMD: %lf" % np.sqrt(d.cpu().sum(-1)).mean())

#test_emd()


0 comments on commit dc6450f

Please sign in to comment.