support c4ai-command-r model #684

Merged
merged 21 commits on Apr 11, 2024
fix
jinghan authored and jinghan committed Apr 11, 2024
commit 5349662e1a09f8e02d938fabfff291b85227e5fd
@@ -3,6 +3,7 @@
CUDA_VISIBLE_DEVICES=0,1,2,3 \
swift infer \
--ckpt_dir "output/c4ai-command-r-plus/vx-xxx/checkpoint-xx" \
--load_dataset_config true \
--load_args_from_ckpt_dir true \
--temperature 0.3 \
--top_p 0.7 \
@@ -29,4 +29,4 @@ swift sft \
--save_steps 100 \
--save_total_limit 2 \
--logging_steps 10 \
--use_flash_attn true \
--use_flash_attn true \
69 changes: 11 additions & 58 deletions swift/llm/utils/model.py
@@ -455,55 +455,6 @@ def _register_model(
TemplateType.mengzi,
support_vllm=True,
support_flash_attn=True)
# @register_model(
# ModelType.c4ai_command_r_v01,
# 'AI-ModelScope/c4ai-command-r-v01',
# LoRATM.llama2,
# TemplateType.c4ai,
# requires=['transformers>=4.39.1'],
# support_vllm=True,
# support_flash_attn=True)
# @register_model(
# ModelType.c4ai_command_r_plus,
# 'AI-ModelScope/c4ai-command-r-plus',
# LoRATM.llama2,
# TemplateType.c4ai,
# requires=['transformers>4.39'],
# support_vllm=True,
# support_flash_attn=True)
def get_model_tokenizer_from_repo(model_dir: str,
torch_dtype: Optional[Dtype],
model_kwargs: Dict[str, Any],
load_model: bool = True,
model_config=None,
tokenizer=None,
automodel_class=AutoModelForCausalLM,
**kwargs):
"""load from an independent repository"""
if model_config is None:
model_config = AutoConfig.from_pretrained(
model_dir, trust_remote_code=True)
if torch_dtype is not None:
model_config.torch_dtype = torch_dtype
if tokenizer is None:
tokenizer = AutoTokenizer.from_pretrained(
model_dir, trust_remote_code=True)
eos_token = kwargs.get('eos_token')
if eos_token is not None:
tokenizer.eos_token = eos_token
model = None
context = kwargs.get('context', nullcontext())
if load_model:
with context:
model = automodel_class.from_pretrained(
model_dir,
config=model_config,
torch_dtype=torch_dtype,
trust_remote_code=True,
**model_kwargs)
return model, tokenizer


@register_model(
ModelType.c4ai_command_r_v01,
'AI-ModelScope/c4ai-command-r-v01',
@@ -520,22 +471,23 @@ def get_model_tokenizer_from_repo(model_dir: str,
requires=['transformers>4.39'],
support_vllm=False,
support_flash_attn=True)
def get_model_tokenizer_c4ai(model_dir: str,
torch_dtype: Optional[Dtype],
model_kwargs: Dict[str, Any],
load_model: bool = True,
model_config=None,
tokenizer=None,
automodel_class=AutoModelForCausalLM,
**kwargs):
def get_model_tokenizer_from_repo(model_dir: str,
torch_dtype: Optional[Dtype],
model_kwargs: Dict[str, Any],
load_model: bool = True,
model_config=None,
tokenizer=None,
automodel_class=AutoModelForCausalLM,
**kwargs):
"""load from an independent repository"""
if model_config is None:
model_config = AutoConfig.from_pretrained(
model_dir, trust_remote_code=True)
if torch_dtype is not None:
model_config.torch_dtype = torch_dtype
if tokenizer is None:
tokenizer = AutoTokenizer.from_pretrained(
model_dir, trust_remote_code=True, use_fast=False)
model_dir, trust_remote_code=True)
eos_token = kwargs.get('eos_token')
if eos_token is not None:
tokenizer.eos_token = eos_token
Expand All @@ -547,6 +499,7 @@ def get_model_tokenizer_c4ai(model_dir: str,
model_dir,
config=model_config,
torch_dtype=torch_dtype,
trust_remote_code=True,
**model_kwargs)
return model, tokenizer

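For reference, a minimal sketch of calling the merged loader directly (not part of this diff): the local checkpoint path and the device_map setting are placeholders, and the import path assumes the file location shown above (swift/llm/utils/model.py).

import torch

from swift.llm.utils.model import get_model_tokenizer_from_repo

# Hypothetical local snapshot of AI-ModelScope/c4ai-command-r-v01
model_dir = '/path/to/c4ai-command-r-v01'

# model_kwargs is forwarded to AutoModelForCausalLM.from_pretrained,
# so device_map='auto' spreads the weights across available GPUs.
model, tokenizer = get_model_tokenizer_from_repo(
    model_dir,
    torch_dtype=torch.bfloat16,
    model_kwargs={'device_map': 'auto'},
    load_model=True)

print(tokenizer.eos_token)

Because the commit folds the c4ai-specific loader back into get_model_tokenizer_from_repo, the same call path serves both the command-r registrations above and the other models that already used this function.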