
Commit

add off-line eval (linear)
doxawang committed Jan 2, 2022
1 parent 424d665 commit 5d8e3f8
Showing 13 changed files with 160 additions and 68 deletions.
2 changes: 1 addition & 1 deletion bash_files/linear/cifar10/run.sh
@@ -1 +1 @@
c0 sh byol.sh cifar10 resnet18 ../../linear/cifar10/datasets ../../pretrain/cifar/trained_models/byol/1m9vggjj/byol-cifar10-1m9vggjj-ep=999.ckpt
c0 sh simclr.sh cifar10 resnet18 ../../linear/cifar10/datasets ../../pretrain/cifar/trained_models/byol/1m9vggjj/byol-cifar10-1m9vggjj-ep=999.ckpt
@@ -1,4 +1,4 @@
python3 ../../../main_eval.py \
python3 ../../../main_linear.py \
--dataset $1 \
--backbone $2 \
--data_dir $3 \
@@ -13,9 +13,8 @@ python3 ../../../main_eval.py \
--weight_decay 0 \
--batch_size 256 \
--num_workers 10 \
--pretrained_feature_extractor PATH \
--name $1-byol-$2-linear-eval \
--name $1-linear \
--entity doxawang \
--project solo-learn \
--pretrained_feature_extractor $4 \
--wandb
$4 \
--wandb
21 changes: 17 additions & 4 deletions bash_files/pretrain/cifar/sweep_poison_method.sh
@@ -1,11 +1,24 @@
# i=0
# for model in byol mocov2plus simsiam supcon
# do
# for file in /data/yfwang/solo-learn/poison_datasets/cifar10/zoo-${model}/*.pt
# do
# CUDA_VISIBLE_DEVICES=${i} sh ${model}.sh cifar10 " --poison_data ${file} --use_poison --checkpoint_dir /data/yfwang/solo-learn/pretrain/cifar10 " &
# i=`expr ${i} + 1`
# CUDA_VISIBLE_DEVICES=${i} sh simclr.sh cifar10 " --poison_data ${file} --use_poison --checkpoint_dir /data/yfwang/solo-learn/pretrain/cifar10 " &
# done
# i=`expr ${i} + 1`
# done

i=0
for model in byol mocov2plus simsiam supcon
for trigger_type in checkerboard_full checkerboard_4corner checkerboard_1corner gaussian_noise
do
for file in /data/yfwang/solo-learn/poison_datasets/cifar10/zoo-${model}/*.pt
for file in /data/yfwang/solo-learn/poison_datasets/cifar10/zoo-simclr/cifar10_zoo-simclr_rate_0.50_target_None_trigger_${trigger_type}*.pt
do
CUDA_VISIBLE_DEVICES=${i} sh ${model}.sh cifar10 " --poison_data ${file} --use_poison --checkpoint_dir /data/yfwang/solo-learn/pretrain/cifar10 " &
i=`expr ${i} + 1`
# echo $trigger_type $file
CUDA_VISIBLE_DEVICES=${i} sh simclr.sh cifar10 " --poison_data ${file} --use_poison --checkpoint_dir /data/yfwang/solo-learn/pretrain/cifar10 " &
# i=`expr ${i} + 1`
# CUDA_VISIBLE_DEVICES=${i} sh simclr.sh cifar10 " --poison_data ${file} --use_poison --checkpoint_dir /data/yfwang/solo-learn/pretrain/cifar10 " &
done
i=`expr ${i} + 1`
done
11 changes: 11 additions & 0 deletions bash_files/pretrain/cifar/sweep_trigger.sh
@@ -0,0 +1,11 @@
i=0
for model in byol mocov2plus simsiam supcon
do
for file in /data/yfwang/solo-learn/poison_datasets/cifar10/zoo-${model}/*.pt
do
CUDA_VISIBLE_DEVICES=${i} sh ${model}.sh cifar10 " --poison_data ${file} --use_poison --checkpoint_dir /data/yfwang/solo-learn/pretrain/cifar10 " &
i=`expr ${i} + 1`
CUDA_VISIBLE_DEVICES=${i} sh simclr.sh cifar10 " --poison_data ${file} --use_poison --checkpoint_dir /data/yfwang/solo-learn/pretrain/cifar10 " &
done
i=`expr ${i} + 1`
done
28 changes: 25 additions & 3 deletions main_linear.py
@@ -76,6 +76,21 @@ def main():
if "swin" in args.backbone and cifar:
kwargs["window_size"] = 4

# specify poison args
# args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.method)

if args.use_poison or args.eval_poison:
assert args.poison_data is not None
poison_data = torch.load(args.poison_data)
prefix = '_poison_' if args.use_poison else '_eval_'
poison_suffix = prefix + poison_data['args'].poison_data_name
print('poison data loaded from', args.poison_data)
args.target_class = poison_data['anchor_label']
else:
poison_data = None
poison_suffix = ''

# load model
backbone = backbone_model(**kwargs)
if "resnet" in args.backbone:
# remove fc layer
@@ -115,21 +130,25 @@ def main():
del args.backbone
model = Class(backbone, **args.__dict__)

train_loader, val_loader = prepare_data(
train_loader, val_loader, poison_val_loader = prepare_data(
args.dataset,
data_dir=args.data_dir,
train_dir=args.train_dir,
val_dir=args.val_dir,
batch_size=args.batch_size,
num_workers=args.num_workers,
use_poison=args.use_poison,
eval_poison=args.eval_poison,
poison_data=poison_data,
)

callbacks = []

# wandb logging
if args.wandb:
wandb_logger = WandbLogger(
name=args.name, project=args.project, entity=args.entity, offline=args.offline
name=args.name + poison_suffix,
project=args.project, entity=args.entity, offline=args.offline
)
wandb_logger.watch(model, log="gradients", log_freq=100)
wandb_logger.log_hyperparams(args)
@@ -164,7 +183,10 @@ def main():
if args.dali:
trainer.fit(model, val_dataloaders=val_loader, ckpt_path=ckpt_path)
else:
trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)
if args.eval_poison:
trainer.fit(model, train_loader, [val_loader, poison_val_loader], ckpt_path=ckpt_path)
else:
trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)


if __name__ == "__main__":
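
Note: main_linear.py reads four fields out of the poison checkpoint: 'pattern', 'mask', 'anchor_label' (which becomes args.target_class), and a nested 'args' namespace carrying poison_data_name and trigger_alpha. A minimal round-trip sketch of that layout, with all concrete shapes and values illustrative rather than taken from the repo:

from argparse import Namespace

import numpy as np
import torch

# Illustrative poison checkpoint; the key names match what the diff reads,
# the shapes and values are assumptions.
poison_data = {
    "pattern": np.zeros((32, 32, 1), dtype=np.uint8),  # trigger pattern
    "mask": np.ones((32, 32, 1), dtype=np.uint8),      # where the trigger applies
    "anchor_label": 0,                                 # becomes args.target_class
    "args": Namespace(poison_data_name="example", trigger_alpha=0.2),
}
torch.save(poison_data, "example_poison.pt")

# main_linear.py then consumes it like this:
loaded = torch.load("example_poison.pt")
prefix = "_poison_"  # "_eval_" when only --eval_poison is set
poison_suffix = prefix + loaded["args"].poison_data_name
target_class = loaded["anchor_label"]
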
1 change: 0 additions & 1 deletion main_poison.py
@@ -41,7 +41,6 @@

from solo.utils.classification_dataloader import prepare_data_no_aug
from poisoning_utils import *
import matplotlib.pyplot as plt

def main():
args = parse_args_linear()
14 changes: 3 additions & 11 deletions main_pretrain.py
@@ -123,23 +123,15 @@ def main():
elif args.dataset in ["imagenet100", "imagenet"] and args.val_dir is None:
val_loader = None
else:
_, val_loader = prepare_data_classification(
_, val_loader, poison_val_loader = prepare_data_classification(
args.dataset,
data_dir=args.data_dir,
train_dir=args.train_dir,
val_dir=args.val_dir,
batch_size=args.batch_size,
num_workers=args.num_workers,
)
if args.eval_poison:
_, poison_val_loader = prepare_data_classification(
args.dataset,
data_dir=args.data_dir,
train_dir=args.train_dir,
val_dir=args.val_dir,
batch_size=args.batch_size,
num_workers=args.num_workers,
poison_data = poison_data,
eval_poison=args.eval_poison,
poison_data=poison_data
)

callbacks = []
2 changes: 1 addition & 1 deletion poisoning_utils.py
@@ -81,7 +81,7 @@ def generate_trigger(trigger_type='checkerboard_center'):
pattern[15 + h, 15 + w, 0] = trigger_value[h+1][w+1]
mask[15 + h, 15 + w, 0] = 1
elif trigger_type == 'checkerboard_full': # checkerboard at the center
pattern = np.array(Image.open('./data/checkboard.png'))
pattern = np.array(Image.open('./data/checkboard.jpg'))
mask = np.ones(shape=(32, 32, 1), dtype=np.uint8)
elif trigger_type == 'gaussian_noise':
pattern = np.array(Image.open('./data/cifar_gaussian_noise.png'))
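
Note: the visible lines belong to the checkerboard_center branch of generate_trigger. Reconstructed as a standalone sketch, with trigger_value assumed to be a 3x3 black-and-white checkerboard (its definition is not part of the diff):

import numpy as np

# Assumed 3x3 checkerboard values; only the indexing below appears in the diff.
trigger_value = [[255, 0, 255],
                 [0, 255, 0],
                 [255, 0, 255]]

pattern = np.zeros(shape=(32, 32, 1), dtype=np.uint8)
mask = np.zeros(shape=(32, 32, 1), dtype=np.uint8)
for h in range(-1, 2):
    for w in range(-1, 2):
        # place the 3x3 trigger at the image center (pixels 14..16)
        pattern[15 + h, 15 + w, 0] = trigger_value[h + 1][w + 1]
        mask[15 + h, 15 + w, 0] = 1
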
3 changes: 3 additions & 0 deletions solo/args/setup.py
@@ -135,6 +135,9 @@ def parse_args_linear() -> argparse.Namespace:
args = parser.parse_args()
additional_setup_linear(args)

if args.use_poison:
args.eval_poison = True

return args


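
Note: the --poison_data, --use_poison, and --eval_poison flags tested here are declared elsewhere in the argument setup and are not shown in this diff; judging from how they are used across the commit, the declarations presumably resemble this sketch:

import argparse

parser = argparse.ArgumentParser()
# Hypothetical declarations inferred from usage in this commit.
parser.add_argument("--poison_data", type=str, default=None)  # path to the poison .pt file
parser.add_argument("--use_poison", action="store_true")      # train on poisoned data
parser.add_argument("--eval_poison", action="store_true")     # also evaluate on a poisoned val set
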
2 changes: 1 addition & 1 deletion solo/methods/base.py
@@ -105,7 +105,7 @@ def __init__(
knn_eval: bool = False,
knn_k: int = 20,
eval_poison: bool = False,
target_class: int = 0,
target_class: int = None,
**kwargs,
):
"""Base model that implements all basic operations for all self-supervised methods.
70 changes: 53 additions & 17 deletions solo/methods/linear.py
@@ -27,7 +27,7 @@
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from solo.methods.base import BaseMethod
from solo.utils.lars import LARSWrapper
from solo.utils.metrics import accuracy_at_k, weighted_mean
from solo.utils.metrics import accuracy_at_k, weighted_mean, false_positive, weighted_sum
from torch.optim.lr_scheduler import (
CosineAnnealingLR,
ExponentialLR,
@@ -51,6 +51,8 @@ def __init__(
extra_optimizer_args: dict,
scheduler: str,
lr_decay_steps: Optional[Sequence[int]] = None,
eval_poison: bool = False,
target_class: int = None,
**kwargs,
):
"""Implements linear evaluation.
@@ -93,6 +95,10 @@ def __init__(
self.scheduler = scheduler
self.lr_decay_steps = lr_decay_steps

# poison
self.eval_poison = eval_poison
self.target_class = target_class

# all the other parameters
self.extra_args = kwargs

@@ -236,12 +242,17 @@ def shared_step(
X, target = batch
batch_size = X.size(0)

out = self(X)["logits"]
outs = self(X)
logits = outs["logits"]

loss = F.cross_entropy(logits, target)

acc1, acc5 = accuracy_at_k(logits, target, top_k=(1, 5))

loss = F.cross_entropy(out, target)
fp_target, fp_all = false_positive(logits, target, self.target_class)

acc1, acc5 = accuracy_at_k(out, target, top_k=(1, 5))
return batch_size, loss, acc1, acc5
return {"outs": outs, "batch_size": batch_size, "loss": loss, "acc1": acc1, "acc5": acc5,
"fp_target": fp_target, "fp_all": fp_all}

def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
"""Performs the training step for the linear eval.
@@ -257,13 +268,15 @@ def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
# set backbone to eval mode
self.backbone.eval()

_, loss, acc1, acc5 = self.shared_step(batch, batch_idx)
outs = self.shared_step(batch, batch_idx)

loss, acc1, acc5 = outs["loss"], outs["acc1"], outs["acc5"]

log = {"train_loss": loss, "train_acc1": acc1, "train_acc5": acc5}
self.log_dict(log, on_epoch=True, sync_dist=True)
return loss

def validation_step(self, batch: torch.Tensor, batch_idx: int) -> Dict[str, Any]:
def validation_step(self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int = None) -> Dict[str, Any]:
"""Performs the validation step for the linear eval.
Args:
@@ -276,13 +289,17 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> Dict[str, Any]
the classification loss and accuracies.
"""

batch_size, loss, acc1, acc5 = self.shared_step(batch, batch_idx)
outs = self.shared_step(batch, batch_idx)

batch_size, loss, acc1, acc5 = outs["batch_size"], outs["loss"], outs["acc1"], outs["acc5"]

results = {
"batch_size": batch_size,
"val_loss": loss,
"val_acc1": acc1,
"val_acc5": acc5,
"batch_size": outs["batch_size"],
"val_loss": outs["loss"],
"val_acc1": outs["acc1"],
"val_acc5": outs["acc5"],
"fp_target": outs["fp_target"],
"fp_all": outs["fp_all"]
}
return results

@@ -295,9 +312,28 @@ def validation_epoch_end(self, outs: List[Dict[str, Any]]):
outs (List[Dict[str, Any]]): list of outputs of the validation step.
"""

val_loss = weighted_mean(outs, "val_loss", "batch_size")
val_acc1 = weighted_mean(outs, "val_acc1", "batch_size")
val_acc5 = weighted_mean(outs, "val_acc5", "batch_size")
log_outs = []

log = {"val_loss": val_loss, "val_acc1": val_acc1, "val_acc5": val_acc5}
self.log_dict(log, sync_dist=True)
if self.eval_poison:
clean_outs, poison_outs = outs
log_outs.append([clean_outs, 'clean_'])
log_outs.append([poison_outs, 'poison_'])
else:
log_outs.append([outs, 'clean_'])

for outs, prefix in log_outs:
val_loss = weighted_mean(outs, "val_loss", "batch_size")
val_acc1 = weighted_mean(outs, "val_acc1", "batch_size")
val_acc5 = weighted_mean(outs, "val_acc5", "batch_size")
val_fp_target = weighted_sum(outs, "fp_target", "batch_size")
val_fp_all = weighted_sum(outs, "fp_all", "batch_size")
val_nfp = val_fp_target * 1.0 / val_fp_all

log = {prefix+"val_loss": val_loss,
prefix+"val_acc1": val_acc1,
prefix+"val_acc5": val_acc5,
prefix+"fp_target": val_fp_target,
prefix+"nfp": val_nfp,
}

self.log_dict(log, sync_dist=True)
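
Note: two things are implicit in this file. First, because main_linear.py now passes a list of val dataloaders to trainer.fit, Lightning calls validation_step with a dataloader_idx and hands validation_epoch_end one output list per loader, which is why outs unpacks into clean_outs, poison_outs. Second, false_positive and weighted_sum are imported from solo.utils.metrics but not shown in the diff; a sketch consistent with how they are used above (the actual implementations may differ):

from typing import Any, Dict, List, Tuple

import torch


def false_positive(logits: torch.Tensor, target: torch.Tensor,
                   target_class: int) -> Tuple[int, int]:
    """Count samples whose true label is not target_class (fp_all) and, among
    those, how many the model assigns to target_class (fp_target). Assumes
    target_class is a valid class index."""
    preds = logits.argmax(dim=1)
    not_target = target != target_class
    fp_target = (preds[not_target] == target_class).sum().item()
    fp_all = not_target.sum().item()
    return fp_target, fp_all


def weighted_sum(outs: List[Dict[str, Any]], key: str, batch_size_key: str) -> float:
    """Plain sum over batches; the batch_size_key argument presumably just
    mirrors weighted_mean's signature."""
    return float(sum(out[key] for out in outs))

Under this reading, the logged nfp is the fraction of non-target-class validation samples that the classifier pushes into the target class, i.e. an attack-success-style rate.
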
53 changes: 33 additions & 20 deletions solo/utils/classification_dataloader.py
@@ -363,6 +363,8 @@ def prepare_data(
batch_size: int = 64,
num_workers: int = 4,
download: bool = True,
use_poison=False,
eval_poison=False,
poison_data = None,
poison_split = 'val',
) -> Tuple[DataLoader, DataLoader]:
@@ -393,32 +395,43 @@
val_dir=val_dir,
download=download,
)


if poison_data is not None:
from poisoning_utils import transform_dataset
if poison_split in ['train', 'all']:
train_dataset = transform_dataset(
dataset,
train_dataset,
poison_data['pattern'],
poison_data['mask'],
poison_data['args'].trigger_alpha
)
if poison_split in ['val', 'all']:
val_dataset = transform_dataset(
dataset,
val_dataset,
poison_data['pattern'],
poison_data['mask'],
poison_data['args'].trigger_alpha
)
from poisoning_utils import transform_dataset

if use_poison:
train_dataset = transform_dataset(
dataset,
train_dataset,
poison_data['pattern'],
poison_data['mask'],
poison_data['args'].trigger_alpha
)

train_loader, val_loader = prepare_dataloaders(
train_dataset,
val_dataset,
batch_size=batch_size,
num_workers=num_workers,
)
return train_loader, val_loader

if eval_poison:
from copy import copy
poison_val_dataset = transform_dataset(
dataset,
copy(val_dataset),
poison_data['pattern'],
poison_data['mask'],
poison_data['args'].trigger_alpha
)

poison_val_loader = DataLoader(
poison_val_dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
shuffle=False,
drop_last=False,
)
return train_loader, val_loader, poison_val_loader
else:
return train_loader, val_loader, None
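
Note: transform_dataset itself is not part of this diff. Given its (dataset name, dataset, pattern, mask, alpha) arguments, a plausible sketch is alpha-blending the trigger into every image; note also that copy(val_dataset) above is a shallow copy, which stays safe as long as transform_dataset rebinds the data attribute instead of mutating it in place, as this sketch does:

import numpy as np


def transform_dataset(dataset_name, dataset, pattern, mask, alpha):
    """Hypothetical: blend `pattern` into each image wherever `mask` is set."""
    imgs = dataset.data.astype(np.float32)  # e.g. CIFAR-10: (N, 32, 32, 3)
    blended = (1 - alpha * mask) * imgs + alpha * mask * pattern
    dataset.data = np.clip(blended, 0, 255).astype(np.uint8)  # rebind, don't mutate
    return dataset
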
