diff --git a/extensions/chamfer_dist/test.py b/extensions/chamfer_dist/test.py index cf45ce3..edf53ba 100644 --- a/extensions/chamfer_dist/test.py +++ b/extensions/chamfer_dist/test.py @@ -2,7 +2,7 @@ # @Author: Haozhe Xie # @Date: 2019-12-10 10:38:01 # @Last Modified by: Haozhe Xie -# @Last Modified time: 2019-12-20 12:44:41 +# @Last Modified time: 2019-12-26 14:21:36 # @Email: cshzxie@gmail.com # # Note: @@ -11,14 +11,22 @@ import os import sys import torch +import unittest from torch.autograd import gradcheck sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))) from extensions.chamfer_dist import ChamferFunction -x = torch.rand(4, 64, 3).double() -y = torch.rand(4, 128, 3).double() -x.requires_grad = True -y.requires_grad = True -print(gradcheck(ChamferFunction.apply, [x.cuda(), y.cuda()])) + +class ChamferDistanceTestCase(unittest.TestCase): + def test_chamfer_dist(self): + x = torch.rand(4, 64, 3).double() + y = torch.rand(4, 128, 3).double() + x.requires_grad = True + y.requires_grad = True + print(gradcheck(ChamferFunction.apply, [x.cuda(), y.cuda()])) + + +if __name__ == '__main__': + unittest.main() diff --git a/extensions/gridding/test.py b/extensions/gridding/test.py index aca29c2..f47cf6b 100644 --- a/extensions/gridding/test.py +++ b/extensions/gridding/test.py @@ -2,7 +2,7 @@ # @Author: Haozhe Xie # @Date: 2019-12-10 10:48:55 # @Last Modified by: Haozhe Xie -# @Last Modified time: 2019-12-20 12:44:19 +# @Last Modified time: 2019-12-26 14:20:42 # @Email: cshzxie@gmail.com # # Note: @@ -11,24 +11,35 @@ import os import sys import torch +import unittest from torch.autograd import gradcheck sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))) from extensions.gridding import GriddingFunction, GriddingReverseFunction -x = torch.rand(2, 4, 4, 4) -x.requires_grad = True -gradcheck(GriddingReverseFunction.apply, [4, x.double().cuda()]) -x = torch.rand(4, 8, 8, 8) -x.requires_grad = True -gradcheck(GriddingReverseFunction.apply, [8, x.double().cuda()]) +class GriddingTestCase(unittest.TestCase): + def test_gridding_reverse_function_4(self): + x = torch.rand(2, 4, 4, 4) + x.requires_grad = True + self.assertTrue(gradcheck(GriddingReverseFunction.apply, [4, x.double().cuda()])) -x = torch.rand(1, 16, 16, 16) -x.requires_grad = True -gradcheck(GriddingReverseFunction.apply, [16, x.double().cuda()]) + def test_gridding_reverse_function_8(self): + x = torch.rand(4, 8, 8, 8) + x.requires_grad = True + self.assertTrue(gradcheck(GriddingReverseFunction.apply, [8, x.double().cuda()])) -y = torch.rand(1, 32, 3) -y.requires_grad = True -gradcheck(GriddingFunction.apply, [y.double().cuda()]) + def test_gridding_reverse_function_16(self): + x = torch.rand(1, 16, 16, 16) + x.requires_grad = True + self.assertTrue(gradcheck(GriddingReverseFunction.apply, [16, x.double().cuda()])) + + def test_gridding_function_32pts(self): + x = torch.rand(1, 32, 3) + x.requires_grad = True + self.assertTrue(gradcheck(GriddingFunction.apply, [x.double().cuda()])) + + +if __name__ == '__main__': + unittest.main()