From f06c3c3d98cbcae4b2cc7970c8f06b61f6c97022 Mon Sep 17 00:00:00 2001 From: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com> Date: Tue, 28 Dec 2021 22:06:19 +0800 Subject: [PATCH] [Fix] fix issue related bugs (#161) * [Fix] fix version check of pytorch * [Fix] fix data format in simclr --- mmselfsup/datasets/builder.py | 2 +- mmselfsup/models/algorithms/simclr.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mmselfsup/datasets/builder.py b/mmselfsup/datasets/builder.py index fdb868326..3c2594c93 100644 --- a/mmselfsup/datasets/builder.py +++ b/mmselfsup/datasets/builder.py @@ -103,7 +103,7 @@ def build_dataloader(dataset, worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None - if digit_version(torch.__version__) >= digit_version('1.7.0'): + if digit_version(torch.__version__) >= digit_version('1.8.0'): kwargs['persistent_workers'] = persistent_workers if kwargs.get('prefetch') is not None: diff --git a/mmselfsup/models/algorithms/simclr.py b/mmselfsup/models/algorithms/simclr.py index d4795bf1f..071b3c7fa 100644 --- a/mmselfsup/models/algorithms/simclr.py +++ b/mmselfsup/models/algorithms/simclr.py @@ -69,7 +69,9 @@ def forward_train(self, img, **kwargs): dict[str, Tensor]: A dictionary of loss components. """ assert isinstance(img, list) - img = torch.cat(img) + img = torch.stack(img, 1) + img = img.reshape( + (img.size(0) * 2, img.size(2), img.size(3), img.size(4))) x = self.extract_feat(img) # 2n z = self.neck(x)[0] # (2n)xd z = z / (torch.norm(z, p=2, dim=1, keepdim=True) + 1e-10)