Pydantic 2.x cfg #1239

Merged · 14 commits · Feb 26, 2024
2 changes: 1 addition & 1 deletion .mypy.ini
@@ -1,5 +1,5 @@
[mypy]

plugins = pydantic.mypy
exclude = venv

[mypy-alpaca_lora_4bit.*]
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -31,6 +31,7 @@ repos:
additional_dependencies:
[
'types-PyYAML',
'pydantic>=2.5.3',
]
- repo: https://github.com/PyCQA/bandit
rev: 1.7.5
8 changes: 1 addition & 7 deletions README.md
@@ -543,7 +543,7 @@ is_mistral_derived_model:
is_qwen_derived_model:

# optional overrides to the base model configuration
model_config:
model_config_overrides:
Collaborator:
I just noticed this, but we would need to deprecate the old name (add a ValueError).

Collaborator (author):
We can't actually deprecate it with Pydantic, because `model_config` is an internal (reserved) attribute name on Pydantic models.
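
A minimal sketch of the collision (Pydantic 2.x; the `Ok`/`Bad` models below are hypothetical, for illustration only):

```python
from pydantic import BaseModel, ConfigDict

# In Pydantic v2, `model_config` is how a model carries its own configuration:
class Ok(BaseModel):
    model_config = ConfigDict(extra="forbid")

# Declaring it as a regular data field instead fails when the class is created,
# so the input schema cannot expose a `model_config` yaml key directly.
try:
    class Bad(BaseModel):
        model_config: dict  # collides with Pydantic's reserved attribute
except Exception as err:
    print(type(err).__name__)  # expected: PydanticUserError in recent Pydantic 2.x
```

This is why the key was renamed to `model_config_overrides` rather than kept and deprecated in place.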

# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
@@ -560,8 +560,6 @@ bnb_config_kwargs:

# Whether you are training a 4-bit GPTQ quantized model
gptq: true
gptq_groupsize: 128 # group size
gptq_model_v1: false # v1 or v2

# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true
@@ -819,10 +817,6 @@ cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosin
# For one_cycle optim
lr_div_factor: # Learning rate div factor

# For log_sweep optim
log_sweep_min_lr:
log_sweep_max_lr:

# Specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see:
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
3 changes: 2 additions & 1 deletion requirements.txt
@@ -6,6 +6,7 @@ tokenizers==0.15.0
bitsandbytes>=0.41.1
accelerate==0.26.1
deepspeed>=0.13.1
pydantic>=2.5.3
addict
fire
PyYAML>=6.0
@@ -27,7 +28,7 @@ scipy
scikit-learn==1.2.2
pynvml
art
fschat==0.2.34
fschat==0.2.36
gradio==3.50.2
tensorboard

19 changes: 17 additions & 2 deletions src/axolotl/cli/__init__.py
@@ -24,11 +24,13 @@
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from transformers.utils import is_torch_bf16_gpu_available

from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.config import (
GPUCapabilities,
normalize_cfg_datasets,
normalize_config,
validate_config,
@@ -328,7 +330,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
# if any options passed in the cli look like valid keys from the yaml,
# overwrite the yaml value
cfg_keys = cfg.keys()
@@ -341,7 +342,21 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
else:
cfg[k] = kwargs[k]

validate_config(cfg)
cfg.axolotl_config_path = config

try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except: # pylint: disable=bare-except # noqa: E722
gpu_version = None

capabilities = GPUCapabilities(
bf16=is_torch_bf16_gpu_available(),
n_gpu=os.environ.get("WORLD_SIZE", 1),
compute_capability=gpu_version,
)

cfg = validate_config(cfg, capabilities=capabilities)

prepare_optim_env(cfg)
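
For reference, a minimal sketch of how the new capabilities object feeds validation; the `GPUCapabilities` field names come from the diff above, while the config values are made-up placeholders:

```python
from axolotl.utils.config import GPUCapabilities, validate_config
from axolotl.utils.dict import DictDefault

# hypothetical minimal config; a real run needs the full required schema
cfg = DictDefault({"base_model": "NousResearch/Llama-2-7b-hf", "load_in_8bit": True})

# e.g. a single A100: bf16 support, compute capability sm_80, one GPU
capabilities = GPUCapabilities(bf16=True, n_gpu=1, compute_capability="sm_80")
cfg = validate_config(cfg, capabilities=capabilities)
```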

src/axolotl/utils/config/__init__.py
@@ -3,11 +3,17 @@
import logging
import os
from pathlib import Path
from typing import Optional

import torch
from transformers.utils import is_torch_bf16_gpu_available

from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlConfigWCapabilities,
AxolotlInputConfig,
)
from axolotl.utils.config.models.internals import GPUCapabilities
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config

@@ -191,7 +197,15 @@ def normalize_cfg_datasets(cfg):
cfg.datasets[idx].conversation = "chatml"


def validate_config(cfg):
def validate_config(cfg: DictDefault, capabilities: Optional[GPUCapabilities] = None):
if capabilities:
return DictDefault(
dict(AxolotlConfigWCapabilities(**cfg.to_dict(), capabilities=capabilities))
)
return DictDefault(dict(AxolotlInputConfig(**cfg.to_dict())))


def legacy_validate_config(cfg):
"""
This is a "pre-validation" step that handles the yaml configuration before we have any
information about the model architecture
@@ -480,9 +494,6 @@ def validate_config(cfg):
if cfg.rope_scaling:
LOG.warning("`rope_scaling` should now be a key under `model_config`")

if cfg.warmup_steps and cfg.warmup_ratio:
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")

if cfg.wandb_run_id and not cfg.wandb_name:
cfg.wandb_name = cfg.wandb_run_id
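
The new model definitions themselves (in src/axolotl/utils/config/models/input/v0_4_1.py and src/axolotl/utils/config/models/internals.py) are not shown on this page; judging from the imports and call sites above, they plausibly have roughly this shape, with every field beyond those visible in the diff being hypothetical:

```python
from typing import Optional

from pydantic import BaseModel


class GPUCapabilities(BaseModel):
    # mirrors the keyword arguments used in cli/__init__.py above
    bf16: bool = False
    n_gpu: int = 1
    compute_capability: Optional[str] = None  # e.g. "sm_80"


class AxolotlInputConfig(BaseModel):
    # ... the full yaml schema; a single field shown for illustration
    load_in_8bit: Optional[bool] = None


class AxolotlConfigWCapabilities(AxolotlInputConfig):
    # validate_config passes capabilities= alongside the yaml keys
    capabilities: GPUCapabilities
```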

3 new empty files added.