Skip to content

Commit

Permalink
Use the 'unittest' package to perform gradient checks.
Browse files Browse the repository at this point in the history
  • Loading branch information
hzxie committed Dec 26, 2019
1 parent 2fdb917 commit db512ab
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 19 deletions.
20 changes: 14 additions & 6 deletions extensions/chamfer_dist/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: Haozhe Xie
# @Date: 2019-12-10 10:38:01
# @Last Modified by: Haozhe Xie
# @Last Modified time: 2019-12-20 12:44:41
# @Last Modified time: 2019-12-26 14:21:36
# @Email: cshzxie@gmail.com
#
# Note:
Expand All @@ -11,14 +11,22 @@
import os
import sys
import torch
import unittest

from torch.autograd import gradcheck

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)))
from extensions.chamfer_dist import ChamferFunction

# Legacy module-level gradient check for ChamferFunction — this is the
# pre-refactor version being removed by this commit in favor of the
# unittest-based test case. Note it only prints the gradcheck result.
x = torch.rand(4, 64, 3).double()
y = torch.rand(4, 128, 3).double()
x.requires_grad = True
y.requires_grad = True
print(gradcheck(ChamferFunction.apply, [x.cuda(), y.cuda()]))

class ChamferDistanceTestCase(unittest.TestCase):
    """Gradient check for the CUDA Chamfer-distance autograd function."""

    def test_chamfer_dist(self):
        # gradcheck compares analytic vs. numeric gradients and needs
        # double precision to stay within its finite-difference tolerance.
        x = torch.rand(4, 64, 3).double()
        y = torch.rand(4, 128, 3).double()
        x.requires_grad = True
        y.requires_grad = True
        # BUG FIX: the original printed the gradcheck result, so the test
        # could never fail. Assert it instead, matching the gridding tests.
        self.assertTrue(gradcheck(ChamferFunction.apply, [x.cuda(), y.cuda()]))


# Allow running the gradient checks directly: `python test.py`.
if __name__ == '__main__':
    unittest.main()
37 changes: 24 additions & 13 deletions extensions/gridding/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: Haozhe Xie
# @Date: 2019-12-10 10:48:55
# @Last Modified by: Haozhe Xie
# @Last Modified time: 2019-12-20 12:44:19
# @Last Modified time: 2019-12-26 14:20:42
# @Email: cshzxie@gmail.com
#
# Note:
Expand All @@ -11,24 +11,35 @@
import os
import sys
import torch
import unittest

from torch.autograd import gradcheck

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)))
from extensions.gridding import GriddingFunction, GriddingReverseFunction

# Legacy module-level gradient checks for GriddingReverseFunction (4^3 and
# 8^3 grids) — the pre-refactor version being removed by this commit in
# favor of the unittest-based test case. Results were silently discarded.
x = torch.rand(2, 4, 4, 4)
x.requires_grad = True
gradcheck(GriddingReverseFunction.apply, [4, x.double().cuda()])

x = torch.rand(4, 8, 8, 8)
x.requires_grad = True
gradcheck(GriddingReverseFunction.apply, [8, x.double().cuda()])
class GriddingTestCase(unittest.TestCase):
def test_gridding_reverse_function_4(self):
x = torch.rand(2, 4, 4, 4)
x.requires_grad = True
self.assertTrue(gradcheck(GriddingReverseFunction.apply, [4, x.double().cuda()]))

# Legacy module-level gradient check for a 16^3 grid — pre-refactor version
# removed by this commit in favor of the unittest-based test case.
x = torch.rand(1, 16, 16, 16)
x.requires_grad = True
gradcheck(GriddingReverseFunction.apply, [16, x.double().cuda()])
def test_gridding_reverse_function_8(self):
x = torch.rand(4, 8, 8, 8)
x.requires_grad = True
self.assertTrue(gradcheck(GriddingReverseFunction.apply, [8, x.double().cuda()]))

# Legacy module-level gradient check for GriddingFunction — pre-refactor
# version removed by this commit in favor of the unittest-based test case.
y = torch.rand(1, 32, 3)
y.requires_grad = True
gradcheck(GriddingFunction.apply, [y.double().cuda()])
def test_gridding_reverse_function_16(self):
x = torch.rand(1, 16, 16, 16)
x.requires_grad = True
self.assertTrue(gradcheck(GriddingReverseFunction.apply, [16, x.double().cuda()]))

def test_gridding_function_32pts(self):
x = torch.rand(1, 32, 3)
x.requires_grad = True
self.assertTrue(gradcheck(GriddingFunction.apply, [x.double().cuda()]))


# Allow running the gradient checks directly: `python test.py`.
if __name__ == '__main__':
    unittest.main()

0 comments on commit db512ab

Please sign in to comment.