Skip to content

Commit

Permalink
VAE updated (fix #340)
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Apr 7, 2021
1 parent 5dfb990 commit 2e16531
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 212 deletions.
44 changes: 20 additions & 24 deletions examples/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def resample_mesh(mesh_cad, density=1):
sample_face_idx = np.zeros((n_samples,), dtype=int)
acc = 0
for face_idx, _n_sample in enumerate(n_samples_per_face):
sample_face_idx[acc: acc + _n_sample] = face_idx
sample_face_idx[acc : acc + _n_sample] = face_idx
acc += _n_sample

r = np.random.rand(n_samples, 2)
Expand Down Expand Up @@ -192,8 +192,7 @@ def __init__(self, phase, transform=None, config=None):

self.root = "./ModelNet40"
fnames = glob.glob(os.path.join(self.root, "chair/train/*.off"))
fnames = sorted([os.path.relpath(fname, self.root)
for fname in fnames])
fnames = sorted([os.path.relpath(fname, self.root) for fname in fnames])
self.files = fnames
assert len(self.files) > 0, "No file loaded"
logging.info(
Expand Down Expand Up @@ -297,8 +296,7 @@ def make_data_loader(
parser.add_argument("--weight_decay", type=float, default=1e-4)
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--stat_freq", type=int, default=50)
parser.add_argument("--weights", type=str,
default="modelnet_reconstruction.pth")
parser.add_argument("--weights", type=str, default="modelnet_reconstruction.pth")
parser.add_argument("--load_optimizer", type=str, default="true")
parser.add_argument("--eval", action="store_true")
parser.add_argument("--max_visualization", type=int, default=4)
Expand Down Expand Up @@ -427,22 +425,22 @@ def __init__(self, resolution, in_nchannel=512):
# pruning
self.pruning = ME.MinkowskiPruning()

@torch.no_grad()
def get_target(self, out, target_key, kernel_size=1):
with torch.no_grad():
target = torch.zeros(len(out), dtype=torch.bool, device=out.device)
cm = out.coordinate_manager
strided_target_key = cm.stride(
target_key,
out.tensor_stride[0],
)
kernel_map = cm.kernel_map(
out.coordinate_map_key,
strided_target_key,
kernel_size=kernel_size,
region_type=1,
)
for k, curr_in in kernel_map.items():
target[curr_in[0].long()] = 1
target = torch.zeros(len(out), dtype=torch.bool, device=out.device)
cm = out.coordinate_manager
strided_target_key = cm.stride(
target_key,
out.tensor_stride[0],
)
kernel_map = cm.kernel_map(
out.coordinate_map_key,
strided_target_key,
kernel_size=kernel_size,
region_type=1,
)
for k, curr_in in kernel_map.items():
target[curr_in[0].long()] = 1
return target

def valid_batch_map(self, batch_map):
Expand Down Expand Up @@ -591,8 +589,7 @@ def train(net, dataloader, device, config):
num_layers, loss = len(out_cls), 0
losses = []
for out_cl, target in zip(out_cls, targets):
curr_loss = crit(out_cl.F.squeeze(), target.type(
out_cl.F.dtype).to(device))
curr_loss = crit(out_cl.F.squeeze(), target.type(out_cl.F.dtype).to(device))
losses.append(curr_loss.item())
loss += curr_loss / num_layers

Expand Down Expand Up @@ -702,8 +699,7 @@ def visualize(net, dataloader, device, config):
train(net, dataloader, device, config)
else:
if not os.path.exists(config.weights):
logging.info(
f"Downloaing pretrained weights. This might take a while...")
logging.info(f"Downloaing pretrained weights. This might take a while...")
urllib.request.urlretrieve(
"https://bit.ly/36d9m1n", filename=config.weights
)
Expand Down
Loading

0 comments on commit 2e16531

Please sign in to comment.