add new
doxawang committed May 20, 2022
1 parent ffce840 commit 597c96e
Showing 5 changed files with 149 additions and 19 deletions.
110 changes: 109 additions & 1 deletion main_eval.py
@@ -424,6 +424,113 @@ def main_draw_poison(args):




def main_backdoor(args):
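    # Backdoor evaluation: load a pretrained backbone, blend a trigger pattern
    # into the CIFAR validation images, and measure accuracy on the result.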

assert args.backbone in BaseMethod._SUPPORTED_BACKBONES
backbone_model = {
"resnet18": resnet18,
"resnet50": resnet50,
"vit_tiny": vit_tiny,
"vit_small": vit_small,
"vit_base": vit_base,
"vit_large": vit_large,
"swin_tiny": swin_tiny,
"swin_small": swin_small,
"swin_base": swin_base,
"swin_large": swin_large,
}[args.backbone]

# initialize backbone
kwargs = args.backbone_args
cifar = kwargs.pop("cifar", False)
# swin specific
if "swin" in args.backbone and cifar:
kwargs["window_size"] = 4

backbone = backbone_model(**kwargs)
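    # ResNet-specific surgery: attach a linear head (or drop fc) and, for CIFAR,
    # swap in a 3x3 stem and remove the initial max-pool.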
if "resnet" in args.backbone:
# remove fc layer
if args.load_linear:
backbone.fc = nn.Linear(backbone.inplanes, args.num_classes)
else:
backbone.fc = nn.Identity()
if cifar:
backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False)
backbone.maxpool = nn.Identity()

assert (
args.pretrained_feature_extractor.endswith(".ckpt")
or args.pretrained_feature_extractor.endswith(".pth")
or args.pretrained_feature_extractor.endswith(".pt")
)
ckpt_path = args.pretrained_feature_extractor

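    # Load the pretrained weights and remap legacy key prefixes
    # ("backbone." -> "", and "classifier." -> "fc." when loading the linear head).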
state = torch.load(ckpt_path)["state_dict"]

    for k in list(state.keys()):
        if "encoder" in k:
            raise Exception(
                "You are using an older checkpoint. "
                "Either use a new one, or convert it by replacing "
                "all 'encoder' occurrences in state_dict with 'backbone'."
            )
        if "backbone" in k:
            state[k.replace("backbone.", "")] = state[k]
        if args.load_linear and "classifier" in k:
            state[k.replace("classifier.", "fc.")] = state[k]
        del state[k]
# prepare model
backbone.load_state_dict(state, strict=False)
backbone = backbone.cuda()
backbone.eval()

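    # Fixed Gaussian-noise image used as the backdoor trigger pattern.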
pattern = np.array(Image.open('./data/cifar_gaussian_noise.png'))

train_loader, val_loader, train_dataset, val_dataset = prepare_data_no_aug(
args.dataset,
data_dir=args.data_dir,
train_dir=args.train_dir,
val_dir=args.val_dir,
batch_size=args.batch_size,
num_workers=args.num_workers,
)
backbone.eval()

val_features, val_labels = inference(backbone, val_loader)

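    # Blend the trigger into the raw validation images in place; val_loader wraps
    # val_dataset, so the evaluation loop below sees the triggered images.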
alpha = args.trigger_alpha
val_dataset.data = ((1-alpha) * val_dataset.data + alpha * pattern).astype(np.uint8)
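    # Build the poisoned validation set and a loader over it.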
poison_val_dataset = transform_dataset('cifar', val_dataset, pattern, 1, 0.2)

poison_val_loader = torch.utils.data.DataLoader(
poison_val_dataset,
batch_size=100,
num_workers=1,
pin_memory=False,
shuffle=False,
drop_last=False,
)

poison_val_features = inference(backbone, poison_val_loader)[0]

# train_features = nn.functional.normalize(train_features, dim=1)
# train_images, train_labels = train_dataset.data, np.array(train_dataset.targets)
# device = torch.device('cuda')
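    # Loss/accuracy of the frozen backbone on the trigger-blended validation set.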
    total_correct = 0
    total_loss = 0.0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.cuda(), labels.cuda()
            output = backbone(images)
            total_loss += nn.functional.cross_entropy(output, labels).item()
            pred = output.argmax(dim=1)
            total_correct += pred.eq(labels.view_as(pred)).sum().item()
    loss = total_loss / len(val_loader)
    acc = total_correct / len(val_loader.dataset)
    print(f"val loss: {loss:.4f}, val acc: {acc:.4f}")

if __name__ == "__main__":
args = parse_args_linear()
# if args.pretrain_method == 'clb':
@@ -432,7 +539,8 @@ def main_draw_poison(args):
# poison_data = main_clb(args)
# else:
# main_tSNE(args)
main_draw_trigger(args)
# main_draw_trigger(args)
main_backdoor(args)
# main_draw_poison(args)

# args = parse_args_linear()
2 changes: 2 additions & 0 deletions solo/losses/__init__.py
@@ -30,6 +30,7 @@
from solo.losses.vibcreg import vibcreg_loss_func
from solo.losses.vicreg import vicreg_loss_func
from solo.losses.wmse import wmse_loss_func
from solo.losses.simclr import w_simclr_loss_func

__all__ = [
"barlow_loss_func",
@@ -45,4 +46,5 @@
"vibcreg_loss_func",
"vicreg_loss_func",
"wmse_loss_func",
"w_simclr_loss_func",
]
37 changes: 37 additions & 0 deletions solo/losses/simclr.py
@@ -57,3 +57,40 @@ def simclr_loss_func(
neg = torch.sum(sim * neg_mask, 1)
loss = -(torch.mean(torch.log(pos / (pos + neg))))
return loss


def w_simclr_loss_func(
z: torch.Tensor, indexes: torch.Tensor, temperature: float = 0.1
) -> torch.Tensor:
"""Computes SimCLR's loss given batch of projected features z
from different views, a positive boolean mask of all positives and
a negative boolean mask of all negatives.
Args:
z (torch.Tensor): (N*views) x D Tensor containing projected features from the views.
indexes (torch.Tensor): unique identifiers for each crop (unsupervised)
or targets of each crop (supervised).
Return:
torch.Tensor: SimCLR loss.
"""

z = F.normalize(z, dim=-1)
gathered_z = gather(z)

sim = torch.exp(torch.einsum("if, jf -> ij", z, gathered_z) / temperature)

gathered_indexes = gather(indexes)

indexes = indexes.unsqueeze(0)
gathered_indexes = gathered_indexes.unsqueeze(0)
# positives
pos_mask = indexes.t() == gathered_indexes
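    # zero self-similarity: the diagonal of this rank's local block pairs each crop with itself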
pos_mask[:, z.size(0) * get_rank() :].fill_diagonal_(0)
# negatives
neg_mask = indexes.t() != gathered_indexes

pos = torch.sum(sim * pos_mask, 1)
neg = torch.sum(sim * neg_mask, 1)
loss = -(torch.mean(torch.log(pos / (pos + neg))))
return loss
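
# Shape sketch (assumed usage, mirroring simclr_loss_func): for N images with two
# views each, z is (2N, D) and indexes holds the 2N image ids (repeated per view):
#   loss = w_simclr_loss_func(z, indexes, temperature=0.1)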
18 changes: 0 additions & 18 deletions solo/methods/rot.py
@@ -90,24 +90,6 @@ def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
"""
return self.base_forward(*args, **kwargs)


def base_forward(self, X: torch.Tensor) -> Dict:
"""Basic forward that allows children classes to override forward().
Args:
X (torch.Tensor): batch of images in tensor format.
Returns:
Dict: dict of logits and features.
"""

feats = self.backbone(X)
logits = self.classifier(feats)
return {
"logits": logits,
"feats": feats,
}

def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
"""Training step for SimSiam reusing BaseMethod training step.
1 change: 1 addition & 0 deletions solo/methods/simclr.py
@@ -24,6 +24,7 @@
import torch.nn as nn
from solo.losses.simclr import simclr_loss_func
from solo.methods.base import BaseMethod
from solo.losses.simclr import w_simclr_loss_func


class SimCLR(BaseMethod):
