Commit ccaf4af: dali
okotaku committed Jan 5, 2024 (1 parent: 4204b93)
Showing 5 changed files with 302 additions and 0 deletions.
15 changes: 15 additions & 0 deletions projects/dali/README.md
@@ -0,0 +1,15 @@
#### Single GPU

Environment:

- A6000 Single GPU
- nvcr.io/nvidia/pytorch:23.10-py3

Settings:

- 1 epoch of training.

| Model | total time |
| :------------------------------------------: | :--------: |
| stable_diffusion_xl_pokemon_blip_fast (fp16) | 9 m 47 s |
| stable_diffusion_xl_pokemon_blip_dali (bf16) | 9 m 44 s |
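
To try the DALI variant, the entry point added in this commit can be launched directly; a likely invocation (assuming the repository root as the working directory and NVIDIA DALI installed):

```bash
python projects/dali/train_dali.py \
    projects/dali/configs/stable_diffusion_xl_pokemon_blip_dali.py
```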
3 changes: 3 additions & 0 deletions projects/dali/__init__.py
@@ -0,0 +1,3 @@
from .sdxl_dali_data_preprocessor import SDXLDALIDataPreprocessor

__all__ = ["SDXLDALIDataPreprocessor"]
43 changes: 43 additions & 0 deletions projects/dali/configs/stable_diffusion_xl_pokemon_blip_dali.py
@@ -0,0 +1,43 @@
_base_ = [
"../../../configs/_base_/models/stable_diffusion_xl.py",
"../../../configs/_base_/datasets/pokemon_blip_xl.py",
"../../../configs/_base_/schedules/stable_diffusion_xl_50e.py",
"../../../configs/_base_/default_runtime.py",
]

custom_imports = dict(imports=["projects.dali"], allow_failed_imports=False)

model = dict(
gradient_checkpointing=False)

train_dataloader = dict(batch_size=1)

#optim_wrapper = dict(
# dtype="bfloat16",
# accumulative_counts=4)

optim_wrapper = dict(
_delete_=True,
optimizer=dict(
type="Adafactor",
lr=1e-5,
weight_decay=1e-2,
scale_parameter=False,
relative_step=False),
clip_grad=dict(max_norm=1.0),
accumulative_counts=4)

env_cfg = dict(
cudnn_benchmark=True,
)

custom_hooks = [
dict(
type="VisualizationHook",
prompt=["yoda pokemon"] * 4,
height=1024,
width=1024),
dict(type="SDCheckpointHook"),
dict(type="FastNormHook", fuse_main_ln=False, fuse_gn=False),
dict(type="CompileHook", compile_main=True),
]
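
This config is meant to be launched through `projects/dali/train_dali.py` below, which builds the Runner with a DALI iterator and swaps in `SDXLDALIDataPreprocessor`. A minimal sketch of inspecting the merged config with mmengine (paths assume the repository root as the working directory):

```python
from mmengine.config import Config

cfg = Config.fromfile(
    "projects/dali/configs/stable_diffusion_xl_pokemon_blip_dali.py")
print(cfg.optim_wrapper.optimizer.type)  # Adafactor
print(cfg.train_dataloader.batch_size)   # 1
# With accumulative_counts=4, the effective per-GPU batch size is 1 * 4 = 4.
```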
40 changes: 40 additions & 0 deletions projects/dali/sdxl_dali_data_preprocessor.py
@@ -0,0 +1,40 @@
import torch
from mmengine.model.base_model.data_preprocessor import BaseDataPreprocessor

from diffengine.registry import MODELS


@MODELS.register_module()
class SDXLDALIDataPreprocessor(BaseDataPreprocessor):
"""SDXLDataPreprocessor."""

def forward(
self,
data: dict,
training: bool = False # noqa
) -> dict | list:
"""Preprocesses the data into the model input format.
After the data pre-processing of :meth:`cast_data`, ``forward``
will stack the input tensor list to a batch tensor at the first
dimension.
Args:
----
data (dict): Data returned by dataloader
training (bool): Whether to enable training time augmentation.
Returns:
-------
dict or list: Data in the same format as the model input.
"""
if "result_class_image" in data["inputs"]:
# dreambooth with class image
data["inputs"]["text"] = data["inputs"]["text"] + data["inputs"][
"result_class_image"].pop("text")
data["inputs"]["img"] = torch.cat([data["inputs"]["img"], data["inputs"][
"result_class_image"].pop("img")], dim=0)
data["inputs"]["time_ids"] = torch.cat([data["inputs"]["time_ids"], data[
"inputs"]["result_class_image"].pop("time_ids")], dim=0)

return super().forward(data)
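
For reference, a minimal sketch (illustrative shapes and prompts) of a DreamBooth-style batch that exercises the concatenation branch above; `result_class_image` carries the prior-preservation class images:

```python
import torch

from projects.dali import SDXLDALIDataPreprocessor

batch = {
    "inputs": {
        "img": torch.randn(2, 3, 1024, 1024),
        "text": ["a sks pokemon", "a sks pokemon"],
        "time_ids": torch.zeros(2, 6),
        "result_class_image": {
            "img": torch.randn(2, 3, 1024, 1024),
            "text": ["a pokemon", "a pokemon"],
            "time_ids": torch.zeros(2, 6),
        },
    },
}

out = SDXLDALIDataPreprocessor()(batch)
# out["inputs"]["img"] and out["inputs"]["time_ids"] are now 4-sample
# batches; out["inputs"]["text"] is a 4-element list.
```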
201 changes: 201 additions & 0 deletions projects/dali/train_dali.py
@@ -0,0 +1,201 @@
# flake8: noqa: PTH122,PTH119,ISC002,E402,ANN201,D103,D101,PD901,PD011,ANN204,D105,D102,A003
import argparse
import os
import os.path as osp

import torch
from mmengine.config import Config, DictAction
from mmengine.runner import Runner


def parse_args():
parser = argparse.ArgumentParser(description="Train a model")
parser.add_argument("config", help="train config file path")
parser.add_argument("--work-dir", help="the dir to save logs and models")
parser.add_argument(
"--resume", action="store_true", help="Whether to resume checkpoint.")
parser.add_argument(
"--amp",
action="store_true",
default=False,
help="enable automatic-mixed-precision training")
parser.add_argument(
"--cfg-options",
nargs="+",
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
"--launcher",
choices=["none", "pytorch", "slurm", "mpi"],
default="none",
help="job launcher")
# When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
# will pass the `--local-rank` parameter to `tools/train.py` instead
# of `--local_rank`.
parser.add_argument("--local_rank", "--local-rank", type=int, default=0)
args = parser.parse_args()
if "LOCAL_RANK" not in os.environ:
os.environ["LOCAL_RANK"] = str(args.local_rank)

return args


def merge_args(cfg, args):
"""Merge CLI arguments to config."""
cfg.launcher = args.launcher

# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get("work_dir", None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join("./work_dirs",
osp.splitext(osp.basename(args.config))[0])

# enable automatic-mixed-precision training
if args.amp is True:
optim_wrapper = cfg.optim_wrapper.get("type", "OptimWrapper")
        assert optim_wrapper in ["OptimWrapper", "AmpOptimWrapper"], \
            "`--amp` is not supported for custom optimizer wrapper type " \
            f"`{optim_wrapper}`."
cfg.optim_wrapper.type = "AmpOptimWrapper"
cfg.optim_wrapper.setdefault("loss_scale", "dynamic")

# resume training
if args.resume:
cfg.resume = True
cfg.load_from = None

if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

return cfg


# DALI-specific imports, kept below the generic CLI helpers
# (E402 is suppressed by the noqa header above).
import pandas as pd
from nvidia.dali import fn, pipeline_def, types
from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy


@pipeline_def(enable_conditionals=True)
def sd_pipeline(rank, world_size, files):
rng = fn.random.coin_flip(probability=0.5)

img_raw, label = fn.readers.file(
files=files,
labels=list(range(len(files))),
name="Reader", shard_id=rank,
num_shards=world_size, random_shuffle=True)
img = fn.decoders.image(
img_raw, device="mixed", output_type=types.RGB)
img = img.gpu()

sizes = fn.shapes(img)

resized = fn.resize(img, device="gpu", resize_shorter=1024,
interp_type=types.INTERP_LINEAR)
resized = fn.flip(resized, horizontal=rng)
sizes2 = fn.shapes(resized)
output = fn.crop_mirror_normalize(
resized,
dtype=types.FLOAT,
crop=(1024, 1024),
device="gpu",
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])
return output, label, sizes, sizes2, rng
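

# A minimal sketch of running this pipeline standalone (the image path is
# illustrative):
#
#     pipe = sd_pipeline(batch_size=1, num_threads=4, device_id=0,
#                        rank=0, world_size=1,
#                        files=["data/pokemon/0000.png"])
#     pipe.build()
#     output, label, sizes, sizes2, rng = pipe.run()
#     print(sizes.as_cpu().as_array())  # original H, W, C of the decoded image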


class Dummy:
    """Placeholder ``dataset`` attribute for the iterator below, so code
    that expects ``dataloader.dataset`` has something to hold on to."""

    def __init__(self) -> None:
        pass

class DaliSDIterator:

def __init__(self) -> None:

if torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
else:
rank = 0
world_size = 1

        # File names and captions of the Pokemon BLIP dataset.
        df = pd.read_csv("data/pokemon/file.csv")
        files = df.file_name.tolist()
        self.caption = df.text.values

pipeline = sd_pipeline(
batch_size=1, num_threads=4, device_id=0,
rank=rank, world_size=world_size, files=files)

self.dali_it = DALIGenericIterator(
pipeline,
["jpg", "label", "sizes", "sizes2", "rng"],
dynamic_shape=False,
reader_name="Reader",
auto_reset=True,
prepare_first_batch=False,
last_batch_policy=LastBatchPolicy.DROP)
self.dataset = Dummy()

    def __next__(self):
        data = self.dali_it.__next__()
        # SDXL micro-conditioning: original size, top-left offset of the
        # centered 1024x1024 crop, and the post-resize size.
        crop_top_left = (data[0]["sizes2"][:, :2] - 1024) / 2
        time_ids = torch.cat([
            data[0]["sizes"][:, :2],
            crop_top_left,
            data[0]["sizes2"][:, :2],
        ], dim=1)
        return dict(inputs=dict(
            img=data[0]["jpg"],
            text=self.caption[data[0]["label"].reshape(-1)],
            time_ids=time_ids))

def next(self):
return self.__next__()

def __iter__(self):
return self

def __len__(self) -> int:
return len(self.caption)
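

# A worked example (illustrative numbers) of the time_ids packing in
# __next__ above. Note the third component is the post-resize size
# (sizes2), not the 1024x1024 crop size:
#
#     original 1280x1920 (H, W) --resize_shorter=1024--> 1024x1536
#     crop_top_left = ((1024 - 1024) / 2, (1536 - 1024) / 2) = (0, 256)
#     time_ids = [1280, 1920, 0, 256, 1024, 1536]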


def main() -> None:
args = parse_args()

# load config
cfg = Config.fromfile(args.config)

# merge cli arguments to config
cfg = merge_args(cfg, args)

    # Swap in the DALI-aware data preprocessor added in this project.
    cfg.model.data_preprocessor = dict(type="SDXLDALIDataPreprocessor")

# build the runner from config
train_loader = DaliSDIterator()
runner = Runner(
model=cfg.model,
train_dataloader=train_loader,
optim_wrapper=cfg.optim_wrapper,
train_cfg=cfg.train_cfg,
launcher=args.launcher,
work_dir=cfg.work_dir,
default_hooks=cfg.default_hooks,
custom_hooks=cfg.custom_hooks,
default_scope=cfg.default_scope,
env_cfg=cfg.env_cfg,
)

# start training
runner.train()


if __name__ == "__main__":
main()
