
Commit

add activate and deactivate for part tuner (#1470)
tastelikefeet committed Jul 23, 2024
1 parent 80d51bc commit adbea0b
Showing 5 changed files with 70 additions and 12 deletions.
2 changes: 1 addition & 1 deletion swift/llm/utils/argument.py
@@ -444,7 +444,7 @@ def prepare_ms_hub(self: Union['SftArguments', 'InferArguments']) -> None:
hub_token = self.hub_token
if hub_token is None:
hub_token = os.environ.get('MODELSCOPE_API_TOKEN')
if hub_token is not None:
if hub_token:
api = HubApi()
api.login(hub_token)
if not hasattr(self, 'push_to_hub') or not self.push_to_hub:
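The change above tightens the token check: a MODELSCOPE_API_TOKEN that is set but empty passes `is not None` yet cannot authenticate. A small illustrative snippet (not from the repository) showing the difference:

```python
import os

# Token variable exists but is empty, e.g. exported blank in CI.
os.environ['MODELSCOPE_API_TOKEN'] = ''
hub_token = os.environ.get('MODELSCOPE_API_TOKEN')

print(hub_token is not None)  # True  -> the old check would still call api.login('')
print(bool(hub_token))        # False -> the new truthiness check skips the login
```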
7 changes: 5 additions & 2 deletions swift/tuners/base.py
@@ -21,7 +21,7 @@
from swift.utils.constants import DEFAULT_ADAPTER, SWIFT_TYPE_KEY
from swift.utils.logger import get_logger
from .. import PeftConfig, PeftModel, get_peft_model
from .utils import SwiftConfig, SwiftOutput
from .utils import SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()

@@ -102,7 +102,7 @@ def model(self):

def load_state_dict(self, state_dict, strict=True, adapter_name: str = None):
if adapter_name is not None:
output = self.adapters[adapter_name]
output: SwiftOutput = self.adapters[adapter_name]
if getattr(output.config, 'modules_to_save', None):
for key, value in copy(state_dict).items():
for module_name in output.config.modules_to_save:
@@ -130,6 +130,9 @@ def load_state_dict(self, state_dict, strict=True, adapter_name: str = None):
key = key.replace('lora_embedding_B.', f'lora_embedding_B.{adapter_name}.')
state_dict[key] = value

if output.load_state_dict_callback:
output.load_state_dict_callback(self.base_model, adapter_name, state_dict)

incompatible_keys = self.base_model.load_state_dict(state_dict, False)
if incompatible_keys and len(incompatible_keys[1]) > 0:
logger.error(f'Load state dict with unexpected keys: {incompatible_keys[1]}')
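With this change, SwiftModel.load_state_dict invokes an optional per-adapter hook just before the weights are applied. A sketch of the contract a tuner is expected to provide (illustrative only; the concrete Part implementation is in swift/tuners/part.py below):

```python
from typing import Dict

import torch
from torch import nn


def load_state_dict_callback(module: nn.Module, adapter_name: str,
                             state_dict: Dict[str, torch.Tensor]) -> None:
    # Called as (base_model, adapter_name, state_dict) right before
    # base_model.load_state_dict(state_dict, strict=False), so a tuner can
    # record which keys belong to this adapter or snapshot the weights that
    # are about to be overwritten.
    ...
```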
40 changes: 32 additions & 8 deletions swift/tuners/part.py
@@ -1,16 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import re
import types
from dataclasses import dataclass, field
from typing import List, Optional, Union
import shutil
from dataclasses import dataclass
from typing import Dict, Optional

import torch
from modelscope.hub.utils.utils import get_cache_dir
from torch import nn

from swift import get_logger
from swift.utils.torch_utils import find_sub_module
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput
from .utils import SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()

@@ -49,8 +49,32 @@ def mark_trainable_callback(model: nn.Module):
if Part.target_module_matched(name, config):
module.requires_grad_(True)

return SwiftOutput(config, state_dict_callback, mark_trainable_callback)
def load_state_dict_callback(module: nn.Module, adapter_name: str, state_dict: Dict[str, torch.Tensor]):
assert adapter_name and '..' not in adapter_name
adapter_keys = state_dict.keys()
original_state_dict = {}
for key, value in module.state_dict().items():
if key in adapter_keys:
original_state_dict[key] = value

setattr(module, f'{adapter_name}.origin', original_state_dict)
setattr(module, f'{adapter_name}.adapter', state_dict)

return SwiftOutput(
config=config,
state_dict_callback=state_dict_callback,
mark_trainable_callback=mark_trainable_callback,
load_state_dict_callback=load_state_dict_callback)

@staticmethod
def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
pass
if activate:
state_dict = getattr(module, f'{adapter_name}.adapter', None)
else:
state_dict = getattr(module, f'{adapter_name}.origin', None)
if state_dict:
incompatible_keys = module.load_state_dict(state_dict, False)
if incompatible_keys and len(incompatible_keys[1]) > 0:
logger.error(f'Load state dict with unexpected keys: {incompatible_keys[1]}')
else:
logger.warn('No state_dict found on the module for part tuner.')
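Taken together, the Part tuner now caches two state dicts on the wrapped model when its checkpoint is loaded: '<adapter_name>.adapter' holds the tuned weights and '<adapter_name>.origin' the weights they replace, and activate_adapter swaps between them. A minimal usage sketch, assuming the flow exercised in tests/tuners/test_swift_base.py; the prepare/save/reload calls are paraphrased from that test, and model, model2, and ckpt_dir are placeholders for your base models and checkpoint directory:

```python
from swift import Swift
from swift.tuners.part import PartConfig

# Train the matched parameters of the base model directly.
part_config = PartConfig(target_modules=r'.*(query|key|value).*')
model = Swift.prepare_model(model, {'part': part_config})

# ... train, then save with model.save_pretrained(ckpt_dir) ...

# Reloading triggers load_state_dict_callback, which stores the tuned weights
# under 'part.adapter' and the weights being replaced under 'part.origin'.
model2 = Swift.from_pretrained(model2, ckpt_dir)

model2.deactivate_adapter('part')  # restore the original weights
model2.activate_adapter('part')    # re-apply the tuned weights
```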
6 changes: 6 additions & 0 deletions swift/tuners/utils.py
@@ -131,12 +131,14 @@ class SwiftOutput:
>>> def mark_trainable_callback(model):
>>> mark_lora_as_trainable(model, config.bias)
optimizer_group_callback (`FunctionType`): A callback returned the param group cared by the tuner.
load_state_dict_callback (`FunctionType`): A callback called before load_state_dict of the tuner.
"""

config: SwiftConfig = None
state_dict_callback: FunctionType = None
mark_trainable_callback: FunctionType = None
optimizer_group_callback: FunctionType = None
load_state_dict_callback: FunctionType = None


class ActivationMixin:
@@ -318,6 +320,10 @@ def load(module: torch.nn.Module, adapter_name, module_key):
module.to(module.origin_device)
delattr(module, 'origin_device')

@staticmethod
def state_dict_load_hook(model: torch.nn.Module, state_dict: Dict[str, torch.Tensor]):
pass

@staticmethod
def has_additional_modules():
return True
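Because the new load_state_dict_callback field defaults to None (as does optimizer_group_callback), tuners that do not need the hook are unaffected. A minimal sketch with placeholder callbacks, only to show that the extra field can simply be omitted:

```python
from swift.tuners.utils import SwiftOutput

output = SwiftOutput(
    state_dict_callback=lambda state_dict, adapter_name: state_dict,
    mark_trainable_callback=lambda model: None)

assert output.load_state_dict_callback is None  # new field defaults to None
```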
27 changes: 26 additions & 1 deletion tests/tuners/test_swift_base.py
@@ -16,7 +16,7 @@
from torch import nn

from swift import AdapterConfig, LoRAConfig, PromptConfig, ResTuningConfig, SideConfig, Swift, SwiftModel
from swift.tuners.part import PartConfig
from swift.tuners.part import Part, PartConfig


class TestSwift(unittest.TestCase):
@@ -283,6 +283,7 @@ def test_swift_multiple_adapters(self):

def test_part(self):
model = SbertForSequenceClassification(SbertConfig())
model_origin = copy.deepcopy(model)
model2 = copy.deepcopy(model)
targets = r'.*(query|key|value).*'
part_config = PartConfig(target_modules=targets)
@@ -308,6 +309,30 @@ def target_in(t: str):
self.assertTrue(key in state_dict2)
self.assertTrue(all(torch.isclose(state_dict[key], state_dict2[key]).flatten().detach().cpu()))

self.assertTrue(hasattr(model2, 'part.adapter'))
self.assertTrue(hasattr(model2, 'part.origin'))

model2.activate_adapter('part')
state_dict = model.state_dict()
state_dict2 = model2.state_dict()
for key in state_dict:
self.assertTrue(key in state_dict2)
self.assertTrue(all(torch.isclose(state_dict[key], state_dict2[key]).flatten().detach().cpu()))

model2.deactivate_adapter('part')
state_dict = model_origin.state_dict()
state_dict2 = model2.state_dict()
for key in state_dict2:
self.assertTrue(key in state_dict)
self.assertTrue(all(torch.isclose(state_dict[key], state_dict2[key]).flatten().detach().cpu()))

model2.activate_adapter('part')
state_dict = model.state_dict()
state_dict2 = model2.state_dict()
for key in state_dict:
self.assertTrue(key in state_dict2)
self.assertTrue(all(torch.isclose(state_dict[key], state_dict2[key]).flatten().detach().cpu()))

def test_swift_multiple_adapters_switching(self):
from swift.tuners.lora import Linear
from swift.tuners.adapter import AdapterModule
