Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Fix like xtuner #119

Merged
merged 5 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix like xtuner
  • Loading branch information
okotaku committed Jan 5, 2024
commit 76f47d442b16a44c7a1117a81770078e2c8dfe63
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ repos:
exclude: |-
(?x)(
^docs
| ^configs
| ^diffengine/configs
| ^projects
)
1 change: 0 additions & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
include diffengine/.mim/model-index.yml
include add_mim_extension.py
recursive-include diffengine/.mim/configs *.py *.yml
recursive-include diffengine/.mim/tools *.sh *.py
recursive-include diffengine/.mim/demo *.sh *.py
Expand Down
75 changes: 0 additions & 75 deletions add_mim_extension.py

This file was deleted.

3 changes: 2 additions & 1 deletion diffengine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .entry_point import cli
from .version import __version__

__all__ = ["__version__"]
__all__ = ["__version__", "cli"]
23 changes: 23 additions & 0 deletions diffengine/configs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# flake8: noqa: PTH122,PTH120
# Copied from xtuner.configs.__init__
import os


def get_cfgs_name_path():
path = os.path.dirname(__file__)
mapping = {}
for root, _, files in os.walk(path):
# Skip if it is a base config
if "_base_" in root:
continue
for file_ in files:
if file_.endswith(
(".py", ".json"),
) and not file_.startswith(".") and not file_.startswith("_"):
mapping[os.path.splitext(file_)[0]] = os.path.join(root, file_)
return mapping


cfgs_name_path = get_cfgs_name_path()

__all__ = ["cfgs_name_path"]
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
145 changes: 145 additions & 0 deletions diffengine/entry_point.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa: S603
import logging
import os
import random
import subprocess
import sys

from mmengine.logging import print_log

import diffengine
from diffengine.tools import copy_cfg, list_cfg, train
from diffengine.tools.model_converters import publish_model2diffusers
from diffengine.tools.preprocess import bucket_ids

# Define valid modes
MODES = ("list-cfg", "copy-cfg",
"train", "convert", "preprocess")

CLI_HELP_MSG = \
f"""
Arguments received: {['diffengine'] + sys.argv[1:]!s}. diffengine commands use the following syntax:

diffengine MODE MODE_ARGS ARGS

Where MODE (required) is one of {MODES}
MODE_ARG (optional) is the argument for specific mode
ARGS (optional) are the arguments for specific command

Some usages for diffengine commands: (See more by using -h for specific command!)

1. List all predefined configs:
diffengine list-cfg
2. Copy a predefined config to a given path:
diffengine copy-cfg $CONFIG $SAVE_FILE
3-1. Fine-tune by a single GPU:
diffengine train $CONFIG
3-2. Fine-tune by multiple GPUs:
NPROC_PER_NODE=$NGPUS NNODES=$NNODES NODE_RANK=$NODE_RANK PORT=$PORT ADDR=$ADDR diffengine dist_train $CONFIG $GPUS
4-1. Convert the pth model to HuggingFace's model:
diffengine convert pth_to_hf $CONFIG $PATH_TO_PTH_MODEL $SAVE_PATH_TO_HF_MODEL
5-1. Preprocess bucket ids:
diffengine preprocess bucket_ids

Run special commands:

diffengine help
diffengine version

GitHub: https://github.com/okotaku/diffengine
""" # noqa: E501


PREPROCESS_HELP_MSG = \
f"""
Arguments received: {['diffengine'] + sys.argv[1:]!s}. diffengine commands use the following syntax:

diffengine MODE MODE_ARGS ARGS

Where MODE (required) is one of {MODES}
MODE_ARG (optional) is the argument for specific mode
ARGS (optional) are the arguments for specific command

Some usages for preprocess: (See more by using -h for specific command!)

1. Preprocess arxiv dataset:
diffengine preprocess bucket_ids

GitHub: https://github.com/InternLM/diffengine
""" # noqa: E501


special = {
"help": lambda: print_log(CLI_HELP_MSG, "current"),
"version": lambda: print_log(diffengine.__version__, "current"),
}
special = {
**special,
**{f"-{k[0]}": v
for k, v in special.items()},
**{f"--{k}": v
for k, v in special.items()},
}

modes: dict = {
"list-cfg": list_cfg.__file__,
"copy-cfg": copy_cfg.__file__,
"train": train.__file__,
"convert": publish_model2diffusers.__file__,
"preprocess": {
"bucket_ids": bucket_ids.__file__,
"--help": lambda: print_log(PREPROCESS_HELP_MSG, "current"),
"-h": lambda: print_log(PREPROCESS_HELP_MSG, "current"),
},
}


def cli() -> None:
"""CLI entry point."""
args = sys.argv[1:]
if not args: # no arguments passed
print_log(CLI_HELP_MSG, "current")
return
if args[0].lower() in special:
special[args[0].lower()]()
return
if args[0].lower() in modes:
try:
module = modes[args[0].lower()]
n_arg = 0
while not isinstance(module, str) and not callable(module):
n_arg += 1
module = module[args[n_arg].lower()]
if callable(module):
module()
else:
nnodes = os.environ.get("NNODES", 1)
nproc_per_node = os.environ.get("NPROC_PER_NODE", 1)
if nnodes == 1 and nproc_per_node == 1:
subprocess.run(["python", module] + args[n_arg + 1:], check=True)
else:
port = os.environ.get("PORT", None)
if port is None:
port: int = random.randint(20000, 29999) # type: ignore[no-redef] # noqa
print_log(f"Use random port: {port}", "current",
logging.WARNING)
torchrun_args = [
f"--nnodes={nnodes}",
f"--node_rank={os.environ.get('NODE_RANK', 0)}",
f"--nproc_per_node={nproc_per_node}",
f"--master_addr={os.environ.get('ADDR', '127.0.0.1')}",
f"--master_port={port}",
]
subprocess.run(["torchrun"] + torchrun_args + [module] +
args[n_arg + 1:] +
["--launcher", "pytorch"], check=True)
except Exception as e:
print_log(f"WARNING: command error: '{e}'!", "current",
logging.WARNING)
print_log(CLI_HELP_MSG, "current", logging.WARNING)
return
else:
print_log("WARNING: command error!", "current", logging.WARNING)
print_log(CLI_HELP_MSG, "current", logging.WARNING)
return
36 changes: 36 additions & 0 deletions diffengine/tools/copy_cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copied from xtuner.tools.copy_cfg
import argparse
import os.path as osp
import shutil

from mmengine.utils import mkdir_or_exist

from diffengine.configs import cfgs_name_path


def parse_args(): # noqa
parser = argparse.ArgumentParser()
parser.add_argument("config_name", help="config name")
parser.add_argument("save_dir", help="save directory for copied config")
return parser.parse_args()


def add_copy_suffix(string) -> str:
file_name, ext = osp.splitext(string)
return f"{file_name}_copy{ext}"


def main() -> None:
"""Main function."""
args = parse_args()
mkdir_or_exist(args.save_dir)
config_path = cfgs_name_path[args.config_name]
save_path = osp.join(args.save_dir,
add_copy_suffix(osp.basename(config_path)))
shutil.copyfile(config_path, save_path)
print(f"Copy to {save_path}")


if __name__ == "__main__":
main()
30 changes: 30 additions & 0 deletions diffengine/tools/list_cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copied from xtuner.tools.list_cfg
import argparse

from diffengine.configs import cfgs_name_path


def parse_args(): # noqa
parser = argparse.ArgumentParser()
parser.add_argument(
"-p", "--pattern", default=None, help="Pattern for fuzzy matching")
return parser.parse_args()


def main() -> None:
"""Main function."""
args = parse_args()
configs_names = sorted(cfgs_name_path.keys())
print("==========================CONFIGS===========================")
if args.pattern is not None:
print(f"PATTERN: {args.pattern}")
print("-------------------------------")
for name in configs_names:
if args.pattern is None or args.pattern.lower() in name.lower():
print(name)
print("=============================================================")


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def get_bucket_id(file_name):
joblib.delayed(get_bucket_id)(file_name)
for file_name in tqdm(img_df.file_name.values))

print(pd.DataFrame(bucket_ids).value_counts())

mmengine.dump(bucket_ids, args.out)

if __name__ == "__main__":
Expand Down
21 changes: 21 additions & 0 deletions diffengine/tools/preprocess/csv_to_txt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import argparse

import pandas as pd


def parse_args(): # noqa
parser = argparse.ArgumentParser(
description="Process a checkpoint to be published")
parser.add_argument("input", help="Path to csv")
parser.add_argument("out", help="Path to output txt")
return parser.parse_args()


def main() -> None:
args = parse_args()

img_df = pd.read_csv(args.input)
img_df.to_csv(args.out, header=False, index=False, sep=" ")

if __name__ == "__main__":
main()
10 changes: 10 additions & 0 deletions tools/train.py → diffengine/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from mmengine.registry import RUNNERS
from mmengine.runner import Runner

from diffengine.configs import cfgs_name_path


def parse_args(): # noqa
parser = argparse.ArgumentParser(description="Train a model")
Expand Down Expand Up @@ -80,6 +82,14 @@ def merge_args(cfg, args): # noqa
def main() -> None:
args = parse_args()

# parse config
if not osp.isfile(args.config):
try:
args.config = cfgs_name_path[args.config]
except KeyError as exc:
msg = f"Cannot find {args.config}"
raise FileNotFoundError(msg) from exc

# load config
cfg = Config.fromfile(args.config)

Expand Down
Loading