Skip to content

Commit

Permalink
[Fix] fix issue related bugs (open-mmlab#161)
Browse files Browse the repository at this point in the history
* [Fix] fix version check of pytorch

* [Fix] fix data format in simclr
  • Loading branch information
fangyixiao18 authored Dec 28, 2021
1 parent e96d81e commit f06c3c3
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mmselfsup/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def build_dataloader(dataset,
worker_init_fn, num_workers=num_workers, rank=rank,
seed=seed) if seed is not None else None

if digit_version(torch.__version__) >= digit_version('1.7.0'):
if digit_version(torch.__version__) >= digit_version('1.8.0'):
kwargs['persistent_workers'] = persistent_workers

if kwargs.get('prefetch') is not None:
Expand Down
4 changes: 3 additions & 1 deletion mmselfsup/models/algorithms/simclr.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def forward_train(self, img, **kwargs):
dict[str, Tensor]: A dictionary of loss components.
"""
assert isinstance(img, list)
img = torch.cat(img)
img = torch.stack(img, 1)
img = img.reshape(
(img.size(0) * 2, img.size(2), img.size(3), img.size(4)))
x = self.extract_feat(img) # 2n
z = self.neck(x)[0] # (2n)xd
z = z / (torch.norm(z, p=2, dim=1, keepdim=True) + 1e-10)
Expand Down

0 comments on commit f06c3c3

Please sign in to comment.