Add Weights and Biases Integration #44

Merged · 4 commits · Jan 13, 2022
Changes from 1 commit
wandb integration
ayulockin committed Jan 12, 2022
commit 27a02f202ef8504c1a85251fa0d54f62b288a0c4
3 changes: 3 additions & 0 deletions config/sample_ddpm_128.json
@@ -90,5 +90,8 @@
"update_ema_every": 1,
"ema_decay": 0.9999
}
},
"wandb": {
"project": "generation_ffhq_ddpm"
}
}
3 changes: 3 additions & 0 deletions config/sample_sr3_128.json
@@ -89,5 +89,8 @@
"update_ema_every": 1,
"ema_decay": 0.9999
}
},
"wandb": {
"project": "generation_ffhq_sr3"
}
}
3 changes: 3 additions & 0 deletions config/sr_ddpm_16_128.json
@@ -90,5 +90,8 @@
"update_ema_every": 1,
"ema_decay": 0.9999
}
},
"wandb": {
"project": "sr_ffhq"
}
}
9 changes: 6 additions & 3 deletions config/sr_sr3_16_128.json
@@ -17,7 +17,7 @@
"name": "FFHQ",
"mode": "HR", // whether need LR img
"dataroot": "dataset/ffhq_16_128",
"datatype": "lmdb", //lmdb or img, path of img files
"datatype": "img", //lmdb or img, path of img files
"l_resolution": 16, // low resolution need to super_resolution
"r_resolution": 128, // high resolution
"batch_size": 4,
@@ -28,8 +28,8 @@
"val": {
"name": "CelebaHQ",
"mode": "LRHR",
"dataroot": "dataset/celebahq_16_128",
"datatype": "lmdb", //lmdb or img, path of img files
"dataroot": "dataset/ffhq_16_128",
"datatype": "img", //lmdb or img, path of img files
"l_resolution": 16,
"r_resolution": 128,
"data_len": 50 // data length in validation
@@ -89,5 +89,8 @@
"update_ema_every": 1,
"ema_decay": 0.9999
}
},
"wandb": {
"project": "sr_ffhq"
}
}
3 changes: 3 additions & 0 deletions config/sr_sr3_64_512.json
@@ -92,5 +92,8 @@
"update_ema_every": 1,
"ema_decay": 0.9999
}
},
"wandb": {
"project": "distributed_high_sr_ffhq"
}
}
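Each config file above gains the same three-line "wandb" block; only the project name varies per task. For orientation, a minimal sketch of how that block is consumed (the real call site is core/wandb_logger.py below; the opt shape here is an assumption):

    import wandb

    # Sketch: the config's project name selects the W&B project for the run.
    # Requires a prior `wandb login`.
    opt = {'wandb': {'project': 'sr_ffhq'}}  # parsed config (assumed shape)
    wandb.init(project=opt['wandb']['project'], config=opt, dir='./experiments')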
8 changes: 8 additions & 0 deletions core/logger.py
@@ -22,6 +22,9 @@ def parse(args):
phase = args.phase
opt_path = args.config
gpu_ids = args.gpu_ids
enable_wandb = args.enable_wandb
log_wandb_ckpt = args.log_wandb_ckpt
log_eval = args.log_eval
# remove comments starting with '//'
json_str = ''
with open(opt_path, 'r') as f:
@@ -72,6 +75,11 @@ def parse(args):
if phase == 'train':
opt['datasets']['val']['data_len'] = 3

# W&B Logging
opt['enable_wandb'] = enable_wandb
opt['log_wandb_ckpt'] = log_wandb_ckpt
opt['log_eval'] = log_eval

return opt


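The hunk above is truncated by the diff view. For context, a self-contained sketch of the //-comment-stripping JSON load that parse() performs (not the repo's exact code, and deliberately naive about '//' inside string values):

    import json

    def load_commented_json(path):
        # Standard JSON forbids comments, so drop everything after '//' on each line.
        # Note: this also truncates any '//' occurring inside a string value.
        json_str = ''
        with open(path, 'r') as f:
            for line in f:
                json_str += line.split('//')[0] + '\n'
        return json.loads(json_str)

    opt = load_commented_json('config/sr_sr3_16_128.json')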
88 changes: 88 additions & 0 deletions core/wandb_logger.py
@@ -0,0 +1,88 @@
import os

class WandbLogger:
"""
Log using `Weights and Biases`.
"""
def __init__(self, opt):
try:
import wandb
except ImportError:
raise ImportError(
"To use the Weights and Biases Logger please install wandb."
"Run `pip install wandb` to install it."
)

self._wandb = wandb

# Initialize a W&B run
if self._wandb.run is None:
self._wandb.init(
project=opt['wandb']['project'],
config=opt,
dir='./experiments'
)

self.config = self._wandb.config

if self.config['log_eval']:
self.eval_table = self._wandb.Table(columns=['fake_image',
'sr_image',
'hr_image',
'psnr',
'ssim'])

def log_metrics(self, metrics, commit=True):
"""
Log train/validation metrics onto W&B.

metrics: dictionary of metrics to be logged
"""
self._wandb.log(metrics, commit=commit)

def log_image(self, key_name, image_array):
"""
Log image array onto W&B.

key_name: name of the key
image_array: numpy array of image.
"""
self._wandb.log({key_name: self._wandb.Image(image_array)})

def log_checkpoint(self, current_epoch, current_step):
"""
Log the model checkpoint as W&B artifacts

current_epoch: the current epoch
current_step: the current batch step
"""
model_artifact = self._wandb.Artifact(
self._wandb.run.id + "_model", type="model"
)

gen_path = os.path.join(
self.config.path['checkpoint'], 'I{}_E{}_gen.pth'.format(current_step, current_epoch))
opt_path = os.path.join(
self.config.path['checkpoint'], 'I{}_E{}_opt.pth'.format(current_step, current_epoch))

model_artifact.add_file(gen_path)
model_artifact.add_file(opt_path)
self._wandb.log_artifact(model_artifact, aliases=["latest"])

def log_eval_data(self, fake_img, sr_img, hr_img, psnr, ssim):
"""
Add data row-wise to the initialized table.
"""
self.eval_table.add_data(
self._wandb.Image(fake_img),
self._wandb.Image(sr_img),
self._wandb.Image(hr_img),
psnr,
ssim
)

def log_eval_table(self, commit=False):
"""
Log the table
"""
self._wandb.log({'eval_data': self.eval_table}, commit=commit)
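A quick usage sketch for the class above. The opt keys are inferred from this diff, the images are placeholders, and the constructor's wandb.init call needs a prior `wandb login`:

    import numpy as np

    opt = {
        'wandb': {'project': 'sr_ffhq'},                   # becomes the W&B project
        'log_eval': True,                                  # creates the eval table
        'path': {'checkpoint': 'experiments/checkpoint'},  # used by log_checkpoint
    }
    wb = WandbLogger(opt)
    wb.log_metrics({'l_pix': 0.042})  # example metric name, not from this diff

    img = np.zeros((128, 128, 3), dtype=np.uint8)  # stand-in for real outputs
    wb.log_eval_data(img, img, img, 27.3, 0.81)    # fake, SR, HR, PSNR, SSIM
    wb.log_eval_table(commit=True)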
59 changes: 55 additions & 4 deletions sr.py
@@ -5,9 +5,11 @@
import logging
import core.logger as Logger
import core.metrics as Metrics
from core.wandb_logger import WandbLogger
from tensorboardX import SummaryWriter
import os
import numpy as np
import wandb

if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -17,6 +19,9 @@
help='Run either train(training) or val(generation)', default='train')
parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
parser.add_argument('-debug', '-d', action='store_true')
parser.add_argument('-enable_wandb', action='store_true')
parser.add_argument('-log_wandb_ckpt', action='store_true')
parser.add_argument('-log_eval', action='store_true')

# parse configs
args = parser.parse_args()
@@ -35,6 +40,16 @@
logger.info(Logger.dict2str(opt))
tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger'])

# Initialize WandbLogger
if opt['enable_wandb']:
wandb_logger = WandbLogger(opt)
wandb.define_metric('validation/val_step')
wandb.define_metric('epoch')
wandb.define_metric("validation/*", step_metric="val_step")
val_step = 0
else:
wandb_logger = None

# dataset
for phase, dataset_opt in opt['datasets'].items():
if phase == 'train' and args.phase != 'val':
@@ -81,6 +96,9 @@
tb_logger.add_scalar(k, v, current_step)
logger.info(message)

if wandb_logger:
wandb_logger.log_metrics(logs)

# validation
if current_step % opt['train']['val_freq'] == 0:
avg_psnr = 0.0
@@ -118,6 +136,12 @@
avg_psnr += Metrics.calculate_psnr(
sr_img, hr_img)

if wandb_logger:
wandb_logger.log_image(
f'validation_{idx}',
np.concatenate((fake_img, sr_img, hr_img), axis=1)
)

avg_psnr = avg_psnr / idx
diffusion.set_new_noise_schedule(
opt['model']['beta_schedule']['train'], schedule_phase='train')
@@ -129,9 +153,23 @@
# tensorboard logger
tb_logger.add_scalar('psnr', avg_psnr, current_step)

if wandb_logger:
wandb_logger.log_metrics({
'validation/val_psnr': avg_psnr,
'validation/val_step': val_step
})
val_step += 1

if current_step % opt['train']['save_checkpoint_freq'] == 0:
logger.info('Saving models and training states.')
diffusion.save_network(current_epoch, current_step)

if wandb_logger and opt['log_wandb_ckpt']:
wandb_logger.log_checkpoint(current_epoch, current_step)

if wandb_logger:
wandb_logger.log_metrics({'epoch': current_epoch-1})

# save model
logger.info('End of training.')
else:
@@ -175,10 +213,15 @@
fake_img, '{}/{}_{}_inf.png'.format(result_path, current_step, idx))

# generation
avg_psnr += Metrics.calculate_psnr(
Metrics.tensor2img(visuals['SR'][-1]), hr_img)
avg_ssim += Metrics.calculate_ssim(
Metrics.tensor2img(visuals['SR'][-1]), hr_img)
eval_psnr = Metrics.calculate_psnr(Metrics.tensor2img(visuals['SR'][-1]), hr_img)
eval_ssim = Metrics.calculate_ssim(Metrics.tensor2img(visuals['SR'][-1]), hr_img)

avg_psnr += eval_psnr
avg_ssim += eval_ssim

if wandb_logger and opt['log_eval']:
wandb_logger.log_eval_data(fake_img, Metrics.tensor2img(visuals['SR'][-1]), hr_img, eval_psnr, eval_ssim)

avg_psnr = avg_psnr / idx
avg_ssim = avg_ssim / idx

@@ -188,3 +231,11 @@
logger_val = logging.getLogger('val') # validation logger
logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}, ssim:{:.4e}'.format(
current_epoch, current_step, avg_psnr, avg_ssim))

if wandb_logger:
if opt['log_eval']:
wandb_logger.log_eval_table()
wandb_logger.log_metrics({
'PSNR': float(avg_psnr),
'SSIM': float(avg_ssim)
})
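Taken together, the integration is opt-in from the command line. Example invocations (the -c/--config flag is assumed from the existing parser, which the hunk above truncates):

    # Train with W&B metric logging plus checkpoint artifacts:
    python sr.py -p train -c config/sr_sr3_16_128.json -enable_wandb -log_wandb_ckpt

    # Validate and additionally log the per-image eval table:
    python sr.py -p val -c config/sr_sr3_16_128.json -enable_wandb -log_eval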