Multimodal Vision Llama - rudimentary support (axolotl-ai-cloud#1940)
---------

Co-authored-by: Sunny <sunny@Sunnys-MacBook-Air.local>
Co-authored-by: sunny <sunnyliu19981005@gmail.com>
3 people authored Oct 3, 2024
1 parent 8443310 commit e1915f5
Showing 24 changed files with 799 additions and 119 deletions.
2 changes: 1 addition & 1 deletion docs/input_output.qmd
@@ -205,7 +205,7 @@ ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
hi there!. goodbye farewell</s>
```

- We can check that the right tokens are ingored by comparing the labels
+ We can check that the right tokens are ignored by comparing the labels
to each token:

```python
28 changes: 28 additions & 0 deletions docs/multimodal.qmd
@@ -0,0 +1,28 @@
# MultiModal / Vision Language Models (BETA)

### Supported Models

- Mllama, i.e. Llama with vision support

### Usage

Multimodal support is currently limited and does not yet have full feature parity with text-only finetuning. To finetune a multimodal Llama with LoRA, add the following to your YAML config alongside the rest of the required hyperparameters.

```yaml
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
processor_type: AutoProcessor
skip_prepare_dataset: true

chat_template: llama3_2_vision
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
remove_unused_columns: false
sample_packing: false

# only finetune the language model; leave the vision model and vision tower frozen
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
```
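
PEFT treats a string `lora_target_modules` value as a regular expression matched against each module name, which is how the vision tower stays frozen while the language model's projection layers receive LoRA adapters. As a quick sanity check, here is a minimal sketch of what the pattern selects (the module names below are illustrative, not pulled from the actual model):

```python
import re

# Same pattern as lora_target_modules above; PEFT matches string targets
# against module names with re.fullmatch.
pattern = r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj"

# Illustrative module names -- the real list comes from model.named_modules().
names = [
    "language_model.model.layers.0.self_attn.q_proj",      # language tower: adapted
    "language_model.model.layers.13.cross_attn.k_proj",    # cross-attention: adapted
    "vision_model.transformer.layers.0.self_attn.q_proj",  # vision tower: stays frozen
]

for name in names:
    print(f"{name} -> {'LoRA target' if re.fullmatch(pattern, name) else 'frozen'}")
```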
63 changes: 63 additions & 0 deletions examples/llama-3-vision/lora-11b.yaml
@@ -0,0 +1,63 @@
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
processor_type: AutoProcessor
strict: false

# these 3 lines are needed for now to handle vision chat templates with images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

chat_template: llama3_2_vision
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

adapter: lora
lora_model_dir:

sequence_len: 8192
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true

gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
7 changes: 5 additions & 2 deletions src/axolotl/cli/__init__.py
@@ -40,7 +40,7 @@
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
- from axolotl.utils.models import load_tokenizer
+ from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
@@ -430,9 +430,12 @@ def load_datasets(
cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg)
+ processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None

train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
- cfg, tokenizer
+ cfg,
+ tokenizer,
+ processor=processor,
)

if cli_args.debug or cfg.debug:
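
`load_processor` is only invoked when `processor_type` is set, so text-only configs keep the old path. Roughly, it builds the multimodal processor for the base model on top of the already-loaded tokenizer; the sketch below is an illustrative approximation (the function name and body are assumptions, not the actual helper in `axolotl.utils.models`):

```python
from transformers import AutoProcessor

def load_processor_sketch(cfg, tokenizer):
    # Approximate idea only: build an image processor + tokenizer pair for the
    # multimodal base model referenced in the config.
    processor = AutoProcessor.from_pretrained(
        cfg.base_model,  # e.g. alpindale/Llama-3.2-11B-Vision-Instruct
        trust_remote_code=bool(cfg.trust_remote_code),
    )
    # Reuse the tokenizer axolotl has already configured (special tokens, etc.).
    processor.tokenizer = tokenizer
    return processor
```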
20 changes: 18 additions & 2 deletions src/axolotl/core/trainer_builder.py
@@ -61,12 +61,14 @@
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
+ from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
+ from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.models import ensure_dtype
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
@@ -250,6 +252,10 @@ class AxolotlTrainingMixins:
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
+ chat_template: Optional[str] = field(
+ default=None,
+ metadata={"help": "Chat template converting chat messages to text"},
+ )


@dataclass
@@ -1043,10 +1049,11 @@ class TrainerBuilderBase(abc.ABC):
_model_ref = None
_peft_config = None

- def __init__(self, cfg, model, tokenizer):
+ def __init__(self, cfg, model, tokenizer, processor=None):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer
+ self.processor = processor

# in case the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls
@@ -1515,6 +1522,10 @@ def build(self, total_num_steps):
)
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
+ if self.cfg.chat_template:
+ training_arguments_kwargs["chat_template"] = chat_templates(
+ self.cfg.chat_template
+ )

if self.cfg.rl == "orpo":
training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
@@ -1661,7 +1672,12 @@ def build_collator(
else:
collator = BatchSamplerDataCollatorForSeq2Seq
else:
- collator = DataCollatorForSeq2Seq
+ if self.cfg.processor_type and self.processor:
+ collator = MultiModalChatDataCollator
+ kwargs["processor"] = self.processor
+ kwargs["chat_template"] = training_args.chat_template
+ else:
+ collator = DataCollatorForSeq2Seq

return collator(
self.tokenizer,
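
When `processor_type` is set, the collator receives both the processor and the resolved chat template. Conceptually, a multimodal chat collator renders each conversation with the template and lets the processor produce the combined text and image tensors; the sketch below illustrates that idea and is not the actual `MultiModalChatDataCollator` (field names such as `messages`/`images` follow the llava-instruct-mix-vsft layout, and the label masking is deliberately simplistic):

```python
def mm_collate_sketch(examples, processor, chat_template):
    # Render each chat with the template, then let the processor build the
    # batched text + image tensors.
    texts = [
        processor.tokenizer.apply_chat_template(
            ex["messages"], chat_template=chat_template, tokenize=False
        )
        for ex in examples
    ]
    images = [ex["images"] for ex in examples]
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # Naive labeling: learn every non-padding token (assumes a pad token is set).
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    batch["labels"] = labels
    return batch
```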