Commit: Fix code (modelscope#1824)
tastelikefeet committed Aug 27, 2024
1 parent 68d5f6f commit c4cbff9
Showing 6 changed files with 41 additions and 27 deletions.
1 change: 1 addition & 0 deletions swift/llm/tuner.py
@@ -227,6 +227,7 @@ def prepare_model(model, args: SftArguments):
     elif args.sft_type == 'reft':
         reft_config = ReftConfig(
             model_type=model_type,
+            layer_key=args.reft_layer_key,
             r=args.reft_rank,
             layers=args.reft_layers,
             intervention_type=args.reft_intervention_type,
1 change: 1 addition & 0 deletions swift/llm/utils/argument.py
@@ -728,6 +728,7 @@ class SftArguments(ArgumentsBase):
     lisa_step_interval: int = 20
 
     # reft
+    reft_layer_key: Optional[str] = None
     reft_layers: Optional[List[int]] = None
     reft_rank: int = 4
     reft_intervention_type: Literal['NoreftIntervention', 'LoreftIntervention', 'ConsreftIntervention',
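Together with the tuner.py hunk above, this threads a user-supplied layer key from SftArguments into ReftConfig. A minimal sketch of setting it programmatically; the field names come from this diff, while the key value and the omitted required arguments (model, dataset, and so on) are illustrative:

    from swift.llm import SftArguments

    # Hedged example: only the reft-related fields from this diff are shown.
    args = SftArguments(
        sft_type='reft',
        reft_layer_key='language_model.layers',  # manual override of the layer lookup
        reft_layers=[8, 16, 24],                 # restrict injection to these layers
        reft_rank=4,
    )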
19 changes: 13 additions & 6 deletions swift/tuners/llamapro.py
@@ -78,15 +78,15 @@ def prepare_model(model: nn.Module, config: LLaMAProConfig, adapter_name: str) -
         LLaMAPro._set_module_list(config, model, new_module_list)
 
         def state_dict_callback(state_dict, adapter_name):
-            model_key_mapping = LLaMAPro._get_model_key_mapping(config.model_type, config)
+            model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
             new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx]
             return {
                 key: value
                 for key, value in state_dict.items() if any([m_part in key for m_part in new_module_list])
             }
 
         def mark_trainable_callback(model):
-            model_key_mapping = LLaMAPro._get_model_key_mapping(config.model_type, config)
+            model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
             new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx]
             for name, parameter in model.named_parameters():
                 parameter: nn.Parameter
@@ -99,7 +99,7 @@ def mark_trainable_callback(model):
     @staticmethod
     def _update_module_attr(config: LLaMAProConfig, module_list):
         model_type = config.model_type
-        model_key_mapping = LLaMAPro._get_model_key_mapping(model_type, config)
+        model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config)
         attention = model_key_mapping.attention
         attention = attention.split('{}.')[1]
         if model_type == 'phi3-small':
@@ -127,9 +127,16 @@ def _update_module_attr(config: LLaMAProConfig, module_list):
             logger.warn(f'model_type: {model_type} seems has no layer_idx, if you encountered anything wrong,'
                         f'please give us a feedback.')
 
+    @classmethod
+    def get_model_key_mapping(cls, model_type, config) -> ModelKeys:
+        model_key_mapping = SwiftAdapter.get_model_key_mapping(model_type, config)
+        assert model_key_mapping.o_proj is not None and model_key_mapping.down_proj is not None, \
+            'LLaMAPro only support models with o_proj and down_proj components.'
+        return model_key_mapping
+
     @staticmethod
     def _update_module_weight(config: LLaMAProConfig, module_list, new_module_idx):
-        model_key_mapping = LLaMAPro._get_model_key_mapping(config.model_type, config)
+        model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
         o_proj = model_key_mapping.o_proj.split('{}.')[1]
         down_proj = model_key_mapping.down_proj.split('{}.')[1]
 
@@ -147,14 +154,14 @@ def _update_module_weight(config: LLaMAProConfig, module_list, new_module_idx):
 
     @staticmethod
     def _set_module_list(config, module: nn.Module, module_list: nn.ModuleList):
-        model_key_mapping = LLaMAPro._get_model_key_mapping(config.model_type, config)
+        model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
         idx = model_key_mapping.module_list.rfind('.')
         parent = module.get_submodule(model_key_mapping.module_list[:idx])
         setattr(parent, model_key_mapping.module_list[idx + 1:], module_list)
 
     @staticmethod
     def _find_module_list(config, module: nn.Module) -> nn.ModuleList:
-        model_key_mapping = LLaMAPro._get_model_key_mapping(config.model_type, config)
+        model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
         return module.get_submodule(model_key_mapping.module_list)
 
     @staticmethod
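In short, the private helper _get_model_key_mapping becomes the public classmethod get_model_key_mapping, and the o_proj/down_proj assertion now lives in LLaMAPro itself rather than in the shared base class (see the utils.py hunk below). A hedged sketch of a call site, assuming 'llama' is registered in MODEL_KEYS_MAPPING so config goes unused:

    from swift.tuners.llamapro import LLaMAPro

    # Resolves the per-architecture key mapping; per the override in this diff it
    # raises AssertionError for architectures without o_proj/down_proj.
    keys = LLaMAPro.get_model_key_mapping('llama', config=None)
    print(keys.module_list)  # e.g. 'model.layers'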
16 changes: 12 additions & 4 deletions swift/tuners/reft.py
@@ -21,6 +21,7 @@ class ReftConfig(SwiftConfig):
     Args:
         model_type(`Optional[str]`): The model_type to find down_proj/layers.
+        layer_key(`Optional[str]`): Manually specify the layer key, for example `language_model.layers`.
         layers (`Optional[List[int]]`): The layer number to inject.
         r(`int`): The rank of Reft.
         intervention_type (`Literal['NoreftIntervention', 'LoreftIntervention',
@@ -31,6 +32,7 @@
     """
 
     model_type: Optional[str] = None
+    layer_key: Optional[str] = None
     layers: Optional[List[int]] = None
     r: int = 4
     intervention_type: Literal['NoreftIntervention', 'LoreftIntervention', 'ConsreftIntervention',
@@ -109,9 +111,12 @@ def forward2(self, base, source=None, subspaces=None):
         NodireftIntervention.forward_origin = NodireftIntervention.forward
         NodireftIntervention.forward = forward2
 
-        model_key_mapping = Reft._get_model_key_mapping(config.model_type, config)
-        logger.info(f'Applying Reft to module: {model_key_mapping.module_list}')
-        module_list: nn.ModuleList = model.get_submodule(model_key_mapping.module_list)
+        module_list_key = config.layer_key
+        if module_list_key is None:
+            model_key_mapping = Reft.get_model_key_mapping(config.model_type, config)
+            module_list_key = model_key_mapping.module_list
+        logger.info(f'Applying Reft to module: {module_list_key}')
+        module_list: nn.ModuleList = model.get_submodule(module_list_key)
         representations = []
         for idx, layer in enumerate(module_list):
             if config.layers and idx not in config.layers:
@@ -120,7 +125,7 @@
                 'layer':
                 idx,
                 'component':
-                model_key_mapping.module_list + f'[{idx}].output',
+                module_list_key + f'[{idx}].output',
                 'low_rank_dimension':
                 config.r,
                 'intervention':
@@ -137,6 +142,9 @@
         def _pre_forward_hook(module, args, kwargs):
             if 'base' in kwargs:
                 return args, kwargs
+
+            if 'input_ids' not in kwargs:
+                raise ValueError('Input does not contain `input_ids`, maybe the model does not support ReFT.')
             # run intervened forward pass
             unit_locations = None
             if 'intervention_locations' in kwargs:
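With layer_key in place, ReFT can target models whose decoder layers are not discoverable from model_type alone, such as a multimodal wrapper whose LLM sits under language_model. A minimal sketch, assuming `model` is already loaded and the key value matches its module tree:

    from swift.tuners import Swift
    from swift.tuners.reft import ReftConfig

    config = ReftConfig(
        layer_key='language_model.layers',  # bypasses the model_type lookup entirely
        layers=[4, 8, 12],                  # None would inject into every layer
        r=4,
        intervention_type='LoreftIntervention',
    )
    model = Swift.prepare_model(model, config)  # `model` assumed defined above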
7 changes: 2 additions & 5 deletions swift/tuners/utils.py
@@ -331,8 +331,8 @@ def load(module: torch.nn.Module, adapter_name, module_key):
             module.to(module.origin_device)
             delattr(module, 'origin_device')
 
-    @staticmethod
-    def _get_model_key_mapping(model_type, config) -> ModelKeys:
+    @classmethod
+    def get_model_key_mapping(cls, model_type, config) -> ModelKeys:
         if model_type in MODEL_KEYS_MAPPING.keys():
             model_key_mapping = MODEL_KEYS_MAPPING[model_type]
         else:
@@ -344,9 +344,6 @@ def _get_model_key_mapping(model_type, config) -> ModelKeys:
 
         if isinstance(model_key_mapping, dict):
             model_key_mapping: ModelKeys = ModelKeys(**model_key_mapping)
-
-        assert model_key_mapping.o_proj is not None and model_key_mapping.down_proj is not None, \
-            'LLaMAPro only support models with o_proj and down_proj components.'
         return model_key_mapping
 
     @staticmethod
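Removing the assertion here keeps the shared lookup generic: SwiftAdapter.get_model_key_mapping no longer hard-codes a LLaMAPro-specific constraint, and any tuner that needs particular keys enforces them in its own override, exactly as llamapro.py does above. A sketch of the pattern; MyTuner and its mlp requirement are hypothetical:

    from swift.tuners.utils import SwiftAdapter

    class MyTuner(SwiftAdapter):

        @classmethod
        def get_model_key_mapping(cls, model_type, config):
            mapping = super().get_model_key_mapping(model_type, config)
            # Hypothetical tuner-specific requirement, enforced in the subclass
            # instead of in the shared base-class lookup.
            assert mapping.mlp is not None, 'MyTuner only supports models with an mlp key.'
            return mapping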
24 changes: 12 additions & 12 deletions swift/utils/module_mapping.py
@@ -3,18 +3,6 @@
 from typing import Optional, Union
 
 
-@dataclasses.dataclass
-class MultiModelKeys:
-
-    language_model: str = None
-
-    projector: Optional[str] = None
-
-    vision_tower: str = None
-
-    vision_resampler: str = None
-
-
 @dataclasses.dataclass
 class ModelKeys:
 
@@ -53,6 +41,18 @@ class ModelKeys:
     output: str = None
 
 
+@dataclasses.dataclass
+class MultiModelKeys(ModelKeys):
+
+    language_model: str = None
+
+    projector: Optional[str] = None
+
+    vision_tower: str = None
+
+    vision_resampler: str = None
+
+
 LLAMA_KEYS = ModelKeys(
     module_list='model.layers',
     mlp='model.layers.{}.mlp',
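Declaring MultiModelKeys after ModelKeys and making it a subclass lets one mapping object carry the multimodal component paths alongside the inherited per-layer keys (module_list, mlp, and friends) that the tuners above read. A hedged sketch; the field names come from this diff, the path values are made up for illustration:

    from swift.utils.module_mapping import MultiModelKeys

    # Hypothetical mapping for a LLaVA-style model.
    LLAVA_LIKE_KEYS = MultiModelKeys(
        language_model='model.language_model',
        vision_tower='model.vision_tower',
        module_list='model.language_model.layers',  # inherited from ModelKeys
    )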
