-
Notifications
You must be signed in to change notification settings - Fork 43
/
dist_chamfer_3D.py
81 lines (66 loc) · 2.69 KB
/
dist_chamfer_3D.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from torch import nn
from torch.autograd import Function
import torch
import importlib
import os
import importlib.util

# Prefer an already-installed "chamfer_3D" extension; otherwise JIT-compile it
# from the C++/CUDA sources that live in the same directory as this file.
# importlib.find_loader was deprecated in Python 3.4 and removed in 3.12;
# importlib.util.find_spec is the supported replacement and likewise returns
# None when the module cannot be located.
chamfer_found = importlib.util.find_spec("chamfer_3D") is not None
if not chamfer_found:
    ## Cool trick from https://github.com/chrdiller
    print("Jitting Chamfer 3D")
    cur_path = os.path.dirname(os.path.abspath(__file__))
    # Every occurrence of "chamfer3D" in this file's directory path is
    # replaced by "tmp" to pick the build directory; if the path does not
    # contain "chamfer3D", the build happens in the source directory itself.
    build_path = cur_path.replace('chamfer3D', 'tmp')
    os.makedirs(build_path, exist_ok=True)
    from torch.utils.cpp_extension import load
    # os.path.join (instead of splitting/joining on "/") keeps the source
    # paths valid on non-POSIX platforms as well.
    chamfer_3D = load(name="chamfer_3D",
                      sources=[
                          os.path.join(cur_path, "chamfer_cuda.cpp"),
                          os.path.join(cur_path, "chamfer3D.cu"),
                      ], build_directory=build_path)
    print("Loaded JIT 3D CUDA chamfer distance")
else:
    import chamfer_3D
    print("Loaded compiled 3D CUDA chamfer distance")
# Chamfer's distance module @thibaultgroueix
# GPU tensors only
class chamfer_3DFunction(Function):
    """Autograd wrapper around the ``chamfer_3D`` CUDA extension.

    ``forward(xyz1, xyz2)`` expects two batched point clouds, ``xyz1`` of
    shape (B, N, 3) and ``xyz2`` of shape (B, M, 3), on the same CUDA device.
    It returns ``(dist1, dist2, idx1, idx2)``:

    - ``dist1`` and ``idx1`` have shape (B, N);
    - ``dist2`` and ``idx2`` have shape (B, M);
    - the ``idx*`` tensors are int32.

    All four are filled in by ``chamfer_3D.forward``. Presumably ``dist*``
    holds the (possibly squared) distance to the nearest neighbour in the
    other cloud and ``idx*`` that neighbour's index; confirm against
    chamfer3D.cu.

    ``backward`` returns gradients with respect to ``xyz1`` and ``xyz2``,
    computed by ``chamfer_3D.backward`` from the incoming ``dist*``
    gradients. Gradients flowing into the index outputs are ignored.
    """

    @staticmethod
    def forward(ctx, xyz1, xyz2):
        batchsize, n, dim = xyz1.size()
        assert dim == 3, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
        batchsize2, m, dim = xyz2.size()
        assert dim == 3, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
        # The outputs are sized from xyz1's batch dimension, so a mismatched
        # xyz2 batch would presumably make the kernel index past xyz2's
        # storage. Reject it up front.
        assert batchsize2 == batchsize, "Batch sizes of the chamfer distance inputs differ! Check with .size()"

        device = xyz1.device
        # Allocate the buffers the extension writes into directly on the
        # target device, rather than building them on the CPU and copying.
        dist1 = torch.zeros(batchsize, n, device=device)
        dist2 = torch.zeros(batchsize, m, device=device)
        idx1 = torch.zeros(batchsize, n, dtype=torch.int32, device=device)
        idx2 = torch.zeros(batchsize, m, dtype=torch.int32, device=device)

        # Run the kernel on the inputs' device. Unlike torch.cuda.set_device,
        # the context manager restores the previous current device afterwards.
        with torch.cuda.device(device):
            chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)

        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
        # Integer index outputs carry no gradient.
        ctx.mark_non_differentiable(idx1, idx2)
        return dist1, dist2, idx1, idx2

    @staticmethod
    def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
        # The extension is handed raw gradient buffers; make them contiguous.
        graddist1 = graddist1.contiguous()
        graddist2 = graddist2.contiguous()
        device = graddist1.device

        gradxyz1 = torch.zeros(xyz1.size(), device=device)
        gradxyz2 = torch.zeros(xyz2.size(), device=device)
        with torch.cuda.device(device):
            chamfer_3D.backward(
                xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
            )
        return gradxyz1, gradxyz2
class chamfer_3DDist(nn.Module):
    """Parameter-free ``nn.Module`` front end for ``chamfer_3DFunction``.

    Both inputs are made contiguous before being handed to the autograd
    function. The result is whatever ``chamfer_3DFunction.apply`` returns.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input1, input2):
        # The extension operates on dense memory, so normalise layout first.
        return chamfer_3DFunction.apply(input1.contiguous(), input2.contiguous())