Phi-3 conversation format, example training script and perplexity metric (axolotl-ai-cloud#1582)

* phi-3 support and perplexity metric

* phi-3 chat template

* metrics updates

* chore: lint

* fix assertion on Tensor

* fix tests since tokenization happens in the metric

* fix perplexity value of shorter passage

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
brianfitzgerald and winglian authored Jun 4, 2024
1 parent c996881 commit cf64284
Showing 10 changed files with 243 additions and 26 deletions.
6 changes: 6 additions & 0 deletions .gitignore
@@ -176,3 +176,9 @@ qlora-out/*
mlruns/*

/.quarto/
prepared-datasets/
submit.sh
*.out*

typings/
out/
64 changes: 64 additions & 0 deletions examples/phi/phi3-ft.yml
@@ -0,0 +1,64 @@
base_model: microsoft/Phi-3-mini-4k-instruct
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
chat_template: phi_3

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
- path: garage-bAInd/Open-Platypus
type: alpaca:phi

dataset_prepared_path:
val_set_size: 0.01
output_dir: ./out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

adapter: lora
lora_model_dir:
lora_r: 64
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0
lr_scheduler: cosine
learning_rate: 5.0e-6

train_on_inputs: false
group_by_length: false
bf16: auto

gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: True
early_stopping_patience: 3
logging_steps: 1
flash_attention: true

eval_steps: 1000
save_steps: 5000
eval_table_size: 2
eval_batch_size: 2
eval_sample_packing: false
eval_max_new_tokens: 32
eval_causal_lm_metrics: ["perplexity"]
do_causal_lm_eval: true

warmup_ratio: 0.2
debug: true
weight_decay: 0.1
resize_token_embeddings_to_32x: true
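
The adapter settings above map, roughly, onto a PEFT LoRA configuration. A minimal sketch of the equivalent object follows; the target_modules names are an assumption for Phi-3-mini and are not part of this commit (lora_target_linear: true tells axolotl to pick the model's linear layers automatically):

# Sketch only: approximately what the lora_* fields above correspond to in PEFT.
# The target_modules list is assumed for Phi-3-mini, not taken from this commit.
from peft import LoraConfig

lora_config = LoraConfig(
    r=64,  # lora_r
    lora_alpha=32,  # lora_alpha
    lora_dropout=0.05,  # lora_dropout
    target_modules=["qkv_proj", "o_proj", "gate_up_proj", "down_proj"],  # assumed
    bias="none",
    task_type="CAUSAL_LM",
)
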
15 changes: 11 additions & 4 deletions src/axolotl/prompters.py
@@ -20,6 +20,7 @@ class PromptStyle(Enum):
INSTRUCT = "instruct"
CHAT = "chat"
CHATML = "chatml"
PHI = "phi"


class Prompter:
@@ -38,9 +39,9 @@ class AlpacaPrompter(Prompter):
system_format: str = "{system}"
turn_format: str
turn_no_input_format: str
prompt_style: Optional[PromptStyle] = None
prompt_style: Optional[str] = None

def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
def __init__(self, prompt_style: Optional[str] = PromptStyle.INSTRUCT.value):
self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value
self.match_prompt_style()

@@ -52,16 +53,20 @@ def match_prompt_style(self):
"### Instruction:\n{instruction}\n\n### Response:\n"
)
self.system_format = "{system}\n\n"
if self.prompt_style == PromptStyle.CHAT.value:
elif self.prompt_style == PromptStyle.CHAT.value:
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
self.system_format = "SYSTEM: {system}\n"
if self.prompt_style == PromptStyle.CHATML.value:
elif self.prompt_style == PromptStyle.CHATML.value:
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
self.turn_no_input_format = (
"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
)
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
elif self.prompt_style == PromptStyle.PHI.value:
self.turn_format = "<|user|>\n{instruction}<|end|>{input}<|assistant|>"
self.turn_no_input_format = "<|user|>\n{instruction}<|end|><|assistant|>"
self.system_format = "<|system|>{system}\n"

def _build_result(self, instruction, input_text, output):
# returns the full prompt from instruction and optional input
@@ -381,12 +386,14 @@ def __init__(
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
role_key_tool: Optional[str] = None,
roles: Optional[dict] = None,
):
super().__init__(
conversation=conversation,
role_key_human=role_key_human,
role_key_model=role_key_model,
role_key_tool=role_key_tool,
roles=roles,
)

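
To see what the new PHI prompt style produces, here is a minimal sketch that renders one turn with the format strings added above; the instruction, input, and system text are made up for illustration, and the way AlpacaPrompter stitches the pieces together is simplified:

# Sketch: render one turn with the PHI format strings added above.
turn_format = "<|user|>\n{instruction}<|end|>{input}<|assistant|>"
turn_no_input_format = "<|user|>\n{instruction}<|end|><|assistant|>"
system_format = "<|system|>{system}\n"

system = system_format.format(system="You are a helpful assistant.")
with_input = turn_format.format(
    instruction="Summarize the passage.",
    input="Phi-3 is a small language model.",
)
no_input = turn_no_input_format.format(instruction="Name three prime numbers.")

print(system + with_input)
# <|system|>You are a helpful assistant.
# <|user|>
# Summarize the passage.<|end|>Phi-3 is a small language model.<|assistant|>
print(system + no_input)
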
34 changes: 24 additions & 10 deletions src/axolotl/utils/callbacks/__init__.py
@@ -5,6 +5,7 @@
import logging
import math
import os
import traceback
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Dict, List
@@ -30,6 +31,7 @@

from axolotl.utils import is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.callbacks.perplexity import Perplexity
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
from axolotl.utils.distributed import (
barrier,
@@ -374,10 +376,14 @@ def __init__(self, cfg):
def __maybe_load_metrics(self):
metrics = {}
for metric in self.cfg.eval_causal_lm_metrics:
try:
metrics[metric] = evaluate.load(metric)
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.warning(f"{metric}: {exc.args}")
if metric == "perplexity":
max_seq_len = self.cfg.eval_max_new_tokens
metrics[metric] = Perplexity(trainer.model, tokenizer, max_seq_len)
else:
try:
metrics[metric] = evaluate.load(metric)
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.warning(f"{metric}: {exc.args}")
return metrics

def on_evaluate(
@@ -421,13 +427,20 @@ def compute(metric: evaluate.Metric, **kwargs):
# safely compute a metric and return the score if the format is correct
metric_score = None
try:
metric_score = metric.compute(**kwargs)
# Only pass the kwargs that are in the metric's feature list
metric_kwargs = {
k: kwargs[k]
for k in metric._feature_names() # pylint: disable=protected-access
if k in kwargs
}
metric_score = metric.compute(**metric_kwargs)
return (
metric_score["score"]
if "score" in metric_score
else metric_score["mean_score"]
)
except Exception: # pylint: disable=broad-exception-caught
traceback.print_exc()
LOG.debug(
f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
)
@@ -443,11 +456,12 @@ def evaluate_preds(sources, predictions, references):
predictions=predictions,
sources=sources,
)
score = score or compute(
metric,
references=[[r] for r in references],
predictions=predictions,
)
if score is None:
score = compute(
metric,
references=[[r] for r in references],
predictions=predictions,
)
scores[metric_name] = score
return scores

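
The kwargs filtering added to compute() lets metrics with different signatures share one call site: the new perplexity metric only declares references, while the translation metrics also take predictions and sources. A standalone sketch of the same pattern, using a hypothetical stub metric:

# Sketch: pass a metric only the keyword arguments it declares, mirroring the
# _feature_names() filtering above. PerplexityStub is a hypothetical stand-in.
class PerplexityStub:
    def _feature_names(self):
        return ["references"]

    def compute(self, references):
        return {"score": 42.0}


def compute_filtered(metric, **kwargs):
    metric_kwargs = {k: kwargs[k] for k in metric._feature_names() if k in kwargs}
    return metric.compute(**metric_kwargs)


# predictions/sources are silently dropped because the stub does not declare them
print(compute_filtered(
    PerplexityStub(),
    references=["a reference passage"],
    predictions=["a model output"],
    sources=["a source prompt"],
))
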
76 changes: 76 additions & 0 deletions src/axolotl/utils/callbacks/perplexity.py
@@ -0,0 +1,76 @@
"""callback to calculate perplexity as an evaluation metric."""
from typing import Dict, List, Optional

import torch
from torch import Tensor
from tqdm import tqdm
from transformers.modeling_outputs import CausalLMOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer


class Perplexity:
"""
Calculate perplexity as defined in https://huggingface.co/docs/transformers/en/perplexity.
This is a custom variant that doesn't re-tokenize the input or re-load the model.
"""

def __init__(
self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
max_seq_len: int,
stride: int = 512,
) -> None:
self.max_seq_len = max_seq_len
self.stride = stride
self.model = model
self.tokenizer = tokenizer
self.device = model.device
self.name = "perplexity"

def _feature_names(self) -> List[str]:
return ["references"]

def compute(
self,
references: Optional[List[str]] = None,
) -> Dict[str, float]:
"""
Compute perplexity in a fixed length sliding window across the sequence.
"""
assert references is not None, "Missing parameter: references"

references_tokenized = self.tokenizer(
references, return_tensors="pt", padding=True, truncation=True
)
input_ids: Tensor = references_tokenized["input_ids"] # type: ignore
input_ids = input_ids.to(self.device)

sequence_length = input_ids.size(1)

losses = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, sequence_length, self.stride)):
end_loc = min(begin_loc + self.max_seq_len, sequence_length)
trg_len = end_loc - prev_end_loc
input_ids_slice = input_ids[:, begin_loc:end_loc]
labels_slice = input_ids_slice.clone()
labels_slice[:, :-trg_len] = -100

with torch.no_grad():
outputs: CausalLMOutput = self.model(
input_ids=input_ids_slice, labels=labels_slice
)

losses.append(outputs.loss)

prev_end_loc = end_loc
if end_loc == sequence_length:
break

perplexity = torch.exp(torch.stack(losses).mean()).item()

return {
"score": perplexity,
}
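
A usage sketch for the new metric outside the trainer; gpt2 is a stand-in model chosen only for illustration (the callback itself passes trainer.model and the training tokenizer), and setting a pad token is needed because compute() pads its reference batch:

# Sketch: compute perplexity over a couple of reference passages.
from transformers import AutoModelForCausalLM, AutoTokenizer

from axolotl.utils.callbacks.perplexity import Perplexity

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # compute() pads the reference batch
model = AutoModelForCausalLM.from_pretrained("gpt2")

metric = Perplexity(model, tokenizer, max_seq_len=128, stride=64)
result = metric.compute(
    references=[
        "Perplexity measures how well a model predicts a passage.",
        "Lower values mean the text is less surprising to the model.",
    ]
)
print(result)  # {"score": <perplexity as a float>}
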
1 change: 1 addition & 0 deletions src/axolotl/utils/chat_templates.py
@@ -25,6 +25,7 @@ def chat_templates(user_choice: str):
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
"phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
}

if user_choice in templates:
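
The phi_3 template can be exercised directly through the tokenizer, which accepts an explicit chat_template string; a sketch, assuming the Phi-3 tokenizer is available locally and using made-up message contents:

# Sketch: render a conversation with the phi_3 Jinja template registered above.
from transformers import AutoTokenizer

from axolotl.utils.chat_templates import chat_templates

tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True
)
messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "What is perplexity?"},
    {"role": "assistant", "content": "A measure of how well a model predicts text."},
]
rendered = tokenizer.apply_chat_template(
    messages, chat_template=chat_templates("phi_3"), tokenize=False
)
print(rendered)
# <s><|system|>
# You are a concise assistant.<|end|>
# <|user|>
# What is perplexity?<|end|>
# <|assistant|>
# A measure of how well a model predicts text.<|end|>
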
6 changes: 3 additions & 3 deletions src/axolotl/utils/config/__init__.py
@@ -10,6 +10,7 @@

from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.config.models.input.v0_4_1 import (
SUPPORTED_METRICS,
AxolotlConfigWCapabilities,
AxolotlInputConfig,
)
@@ -586,13 +587,12 @@ def legacy_validate_config(cfg):
)

if cfg.eval_causal_lm_metrics:
supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
if not isinstance(cfg.eval_causal_lm_metrics, list):
raise ValueError("eval_causal_lm_metrics must be a list")
# only ["sacrebleu", "comet", "ter", "chrf"] supported
if set(cfg.eval_causal_lm_metrics) - set(supported_metrics):
if set(cfg.eval_causal_lm_metrics) - SUPPORTED_METRICS:
raise ValueError(
f"eval_causal_lm_metrics must be one of {supported_metrics}"
f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
)

# TODO
8 changes: 5 additions & 3 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -17,6 +17,8 @@

LOG = logging.getLogger("axolotl.utils.config.models.input")

SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}


class DeprecatedParameters(BaseModel):
"""configurations that are deprecated"""
@@ -176,6 +178,7 @@ class ChatTemplate(str, Enum):
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name


class LoftQConfig(BaseModel):
@@ -1073,13 +1076,12 @@ def check_causal_lm_evals(cls, data):
)

if data.get("eval_causal_lm_metrics"):
supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
if not isinstance(data.get("eval_causal_lm_metrics"), list):
raise ValueError("eval_causal_lm_metrics must be a list")
# only ["sacrebleu", "comet", "ter", "chrf"] supported
if set(data.get("eval_causal_lm_metrics")) - set(supported_metrics):
if set(data.get("eval_causal_lm_metrics")) - SUPPORTED_METRICS:
raise ValueError(
f"eval_causal_lm_metrics must be one of {supported_metrics}"
f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
)
return data

18 changes: 12 additions & 6 deletions src/axolotl/utils/data/sft.py
@@ -474,12 +474,16 @@ def load_prepare_datasets(
index=cfg.dataset_shard_idx,
)

if split == "train" and cfg.val_set_size:
val_set_size = (
int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size)
)

if split == "train" and val_set_size:
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
to_hash_train = (
dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ str(val_set_size)
+ "|"
+ "train"
+ "|"
@@ -488,7 +492,7 @@
to_hash_test = (
dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ str(val_set_size)
+ "|"
+ "test"
+ "|"
@@ -498,9 +502,7 @@
test_fingerprint = md5(to_hash_test)

dataset = dataset.train_test_split(
test_size=int(cfg.val_set_size)
if cfg.val_set_size == int(cfg.val_set_size)
else cfg.val_set_size,
test_size=val_set_size,
shuffle=False,
seed=cfg.seed or 42,
train_new_fingerprint=train_fingerprint,
@@ -535,6 +537,10 @@ def get_dataset_wrapper(
"keep_in_memory": cfg.dataset_keep_in_memory is True,
}

LOG.info(
f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
)

if (
isinstance(dataset, Dataset)
and "input_ids" in dataset.features
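
The val_set_size normalization above mirrors how the datasets library interprets test_size: an integer greater than 1 is an absolute number of eval rows, while a float is a fraction of the dataset. A small sketch under that assumption:

# Sketch: how val_set_size maps onto datasets.Dataset.train_test_split.
from datasets import Dataset

ds = Dataset.from_dict({"text": [f"example {i}" for i in range(100)]})

by_fraction = ds.train_test_split(test_size=0.01, shuffle=False, seed=42)
by_count = ds.train_test_split(test_size=5, shuffle=False, seed=42)

print(len(by_fraction["test"]))  # 1 row (1% of 100)
print(len(by_count["test"]))     # 5 rows (absolute count)
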