diff --git a/main_eval.py b/main_eval.py
index 78961da..c1d8599 100644
--- a/main_eval.py
+++ b/main_eval.py
@@ -424,6 +424,113 @@ def main_draw_poison(args):
+
+def main_backdoor(args):
+
+    assert args.backbone in BaseMethod._SUPPORTED_BACKBONES
+    backbone_model = {
+        "resnet18": resnet18,
+        "resnet50": resnet50,
+        "vit_tiny": vit_tiny,
+        "vit_small": vit_small,
+        "vit_base": vit_base,
+        "vit_large": vit_large,
+        "swin_tiny": swin_tiny,
+        "swin_small": swin_small,
+        "swin_base": swin_base,
+        "swin_large": swin_large,
+    }[args.backbone]
+
+    # initialize backbone
+    kwargs = args.backbone_args
+    cifar = kwargs.pop("cifar", False)
+    # swin specific
+    if "swin" in args.backbone and cifar:
+        kwargs["window_size"] = 4
+
+    backbone = backbone_model(**kwargs)
+    if "resnet" in args.backbone:
+        # remove fc layer
+        if args.load_linear:
+            backbone.fc = nn.Linear(backbone.inplanes, args.num_classes)
+        else:
+            backbone.fc = nn.Identity()
+        if cifar:
+            backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False)
+            backbone.maxpool = nn.Identity()
+
+    assert (
+        args.pretrained_feature_extractor.endswith(".ckpt")
+        or args.pretrained_feature_extractor.endswith(".pth")
+        or args.pretrained_feature_extractor.endswith(".pt")
+    )
+    ckpt_path = args.pretrained_feature_extractor
+
+    state = torch.load(ckpt_path)["state_dict"]
+
+    for k in list(state.keys()):
+        if "encoder" in k:
+            raise Exception(
+                "You are using an older checkpoint. "
+                "Either use a new one, or convert it by replacing "
+                "all 'encoder' occurrences in state_dict with 'backbone'"
+            )
+        if "backbone" in k:
+            state[k.replace("backbone.", "")] = state[k]
+        if args.load_linear:
+            if "classifier" in k:
+                state[k.replace("classifier.", "fc.")] = state[k]
+        del state[k]
+    # prepare model
+    backbone.load_state_dict(state, strict=False)
+    backbone = backbone.cuda()
+    backbone.eval()
+
+    pattern = np.array(Image.open('./data/cifar_gaussian_noise.png'))
+
+    train_loader, val_loader, train_dataset, val_dataset = prepare_data_no_aug(
+        args.dataset,
+        data_dir=args.data_dir,
+        train_dir=args.train_dir,
+        val_dir=args.val_dir,
+        batch_size=args.batch_size,
+        num_workers=args.num_workers,
+    )
+    backbone.eval()
+
+    val_features, val_labels = inference(backbone, val_loader)
+
+    alpha = args.trigger_alpha
+    val_dataset.data = ((1 - alpha) * val_dataset.data + alpha * pattern).astype(np.uint8)
+    poison_val_dataset = transform_dataset('cifar', val_dataset, pattern, 1, 0.2)
+
+    poison_val_loader = torch.utils.data.DataLoader(
+        poison_val_dataset,
+        batch_size=100,
+        num_workers=1,
+        pin_memory=False,
+        shuffle=False,
+        drop_last=False,
+    )
+
+    poison_val_features = inference(backbone, poison_val_loader)[0]
+
+    # train_features = nn.functional.normalize(train_features, dim=1)
+    # train_images, train_labels = train_dataset.data, np.array(train_dataset.targets)
+    # device = torch.device('cuda')
+    total_correct = 0
+    total_loss = 0.0
+    with torch.no_grad():
+        for i, (images, labels) in enumerate(val_loader):
+            images, labels = images.cuda(), labels.cuda()
+            output = backbone(images)
+            total_loss += nn.functional.cross_entropy(output, labels).item()
+            pred = output.data.max(1)[1]
+            total_correct += pred.eq(labels.data.view_as(pred)).sum()
+    loss = total_loss / len(val_loader)
+    acc = float(total_correct) / len(val_loader.dataset)
+    print(acc)
+
 if __name__ == "__main__":
     args = parse_args_linear()
     # if args.pretrain_method == 'clb':
@@ -432,7 +539,8 @@ def main_draw_poison(args):
     # poison_data = main_clb(args)
     # else:
     # main_tSNE(args)
-    main_draw_trigger(args)
+    # main_draw_trigger(args)
+    main_backdoor(args)
     # main_draw_poison(args)
 
     # args = parse_args_linear()
diff --git a/solo/losses/__init__.py b/solo/losses/__init__.py
index 4c76a8e..a25fb6e 100644
--- a/solo/losses/__init__.py
+++ b/solo/losses/__init__.py
@@ -30,6 +30,7 @@
 from solo.losses.vibcreg import vibcreg_loss_func
 from solo.losses.vicreg import vicreg_loss_func
 from solo.losses.wmse import wmse_loss_func
+from solo.losses.simclr import w_simclr_loss_func
 
 __all__ = [
     "barlow_loss_func",
@@ -45,4 +46,5 @@
     "vibcreg_loss_func",
     "vicreg_loss_func",
     "wmse_loss_func",
+    "w_simclr_loss_func",
 ]
diff --git a/solo/losses/simclr.py b/solo/losses/simclr.py
index 80345a5..cc4a5d7 100644
--- a/solo/losses/simclr.py
+++ b/solo/losses/simclr.py
@@ -57,3 +57,40 @@ def simclr_loss_func(
     neg = torch.sum(sim * neg_mask, 1)
     loss = -(torch.mean(torch.log(pos / (pos + neg))))
     return loss
+
+
+def w_simclr_loss_func(
+    z: torch.Tensor, indexes: torch.Tensor, temperature: float = 0.1
+) -> torch.Tensor:
+    """Computes SimCLR's loss given batch of projected features z
+    from different views, a positive boolean mask of all positives and
+    a negative boolean mask of all negatives.
+
+    Args:
+        z (torch.Tensor): (N*views) x D Tensor containing projected features from the views.
+        indexes (torch.Tensor): unique identifiers for each crop (unsupervised)
+            or targets of each crop (supervised).
+
+    Return:
+        torch.Tensor: SimCLR loss.
+    """
+
+    z = F.normalize(z, dim=-1)
+    gathered_z = gather(z)
+
+    sim = torch.exp(torch.einsum("if, jf -> ij", z, gathered_z) / temperature)
+
+    gathered_indexes = gather(indexes)
+
+    indexes = indexes.unsqueeze(0)
+    gathered_indexes = gathered_indexes.unsqueeze(0)
+    # positives
+    pos_mask = indexes.t() == gathered_indexes
+    pos_mask[:, z.size(0) * get_rank() :].fill_diagonal_(0)
+    # negatives
+    neg_mask = indexes.t() != gathered_indexes
+
+    pos = torch.sum(sim * pos_mask, 1)
+    neg = torch.sum(sim * neg_mask, 1)
+    loss = -(torch.mean(torch.log(pos / (pos + neg))))
+    return loss
diff --git a/solo/methods/rot.py b/solo/methods/rot.py
index e07a253..a8ac078 100644
--- a/solo/methods/rot.py
+++ b/solo/methods/rot.py
@@ -90,24 +90,6 @@ def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
         """
         return self.base_forward(*args, **kwargs)
 
-
-    def base_forward(self, X: torch.Tensor) -> Dict:
-        """Basic forward that allows children classes to override forward().
-
-        Args:
-            X (torch.Tensor): batch of images in tensor format.
-
-        Returns:
-            Dict: dict of logits and features.
-        """
-
-        feats = self.backbone(X)
-        logits = self.classifier(feats)
-        return {
-            "logits": logits,
-            "feats": feats,
-        }
-
     def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
         """Training step for SimSiam reusing BaseMethod training step.
diff --git a/solo/methods/simclr.py b/solo/methods/simclr.py
index 4b3c5a7..a297d0d 100644
--- a/solo/methods/simclr.py
+++ b/solo/methods/simclr.py
@@ -24,6 +24,7 @@
 import torch.nn as nn
 from solo.losses.simclr import simclr_loss_func
 from solo.methods.base import BaseMethod
+from solo.losses.simclr import w_simclr_loss_func
 
 
 class SimCLR(BaseMethod):