Fix compatibility with transformers 4.36 (AutoGPTQ#483)
* compatibility with transformers 4.36

* fix
fxmarty committed Dec 14, 2023
1 parent 1ce453f commit ccb6386
Showing 2 changed files with 23 additions and 14 deletions.
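
Note (context, not part of the commit): transformers 4.36 replaced the tuple-based past_key_value with Cache objects such as transformers.DynamicCache, which is why the fused attention module below now takes a layer_idx and calls get_usable_length/update. A minimal sketch of that cache interface, assuming transformers>=4.36 is installed; shapes and values are made up for illustration:

# Illustrative only: the transformers>=4.36 cache API that the patched forward() targets.
import torch
from transformers import DynamicCache

cache = DynamicCache()
layer_idx = 0

# Dummy [batch, num_heads, seq_len, head_dim] key/value states.
key_states = torch.randn(1, 8, 5, 64)
value_states = torch.randn(1, 8, 5, 64)

# Before any update, the cache holds nothing usable for this layer.
print(cache.get_usable_length(5, layer_idx))  # 0

# update() stores the new states for this layer and returns the full key/value so far.
key_states, value_states = cache.update(key_states, value_states, layer_idx)
print(key_states.shape)  # torch.Size([1, 8, 5, 64])
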
4 changes: 2 additions & 2 deletions auto_gptq/modeling/_base.py
@@ -938,8 +938,8 @@ def skip(*args, **kwargs):
         if low_cpu_mem_usage:
             make_sure_no_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size, bits=quantize_config.bits)

-        # Patch until 0.25.0 is released and includes this fix: https://github.com/huggingface/accelerate/pull/2116
-        if version.parse(accelerate.__version__) < version.parse("0.24.99") or accelerate.__version__ == "0.25.0.dev0":
+        # Patch until 0.26.0 is released and includes this fix: https://github.com/huggingface/accelerate/pull/2116
+        if version.parse(accelerate.__version__) < version.parse("0.25.99"):
             original_set_module_tensor_to_device = accelerate.utils.modeling.set_module_tensor_to_device
             accelerate.utils.modeling.set_module_tensor_to_device = set_module_tensor_to_device_patched

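Note (illustrative, not part of the commit): the updated gate keeps the monkey-patch for every accelerate release that predates the upstream fix, while no longer matching 0.26 pre-release builds, which presumably already carry accelerate PR #2116. A quick check of how the "0.25.99" cutoff classifies versions with packaging:

from packaging import version

cutoff = version.parse("0.25.99")
for v in ["0.24.1", "0.25.0", "0.26.0.dev0", "0.26.0"]:
    # True -> the patch is applied; False -> accelerate is assumed to already include the fix.
    print(v, version.parse(v) < cutoff)
# 0.24.1 True
# 0.25.0 True
# 0.26.0.dev0 False
# 0.26.0 False
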
33 changes: 21 additions & 12 deletions auto_gptq/nn_modules/fused_llama_attn.py
@@ -18,11 +18,13 @@ def __init__(
         qkv_proj,
         o_proj,
         rotary_emb,
+        layer_idx,
     ):
         super().__init__()
         self.hidden_size = hidden_size
         self.num_heads = num_heads
         self.head_dim = hidden_size // num_heads
+        self.layer_idx = layer_idx

         if self.head_dim * num_heads != self.hidden_size:
             raise ValueError(
@@ -59,33 +61,36 @@ def forward(

         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index. Please open an issue in AutoGPTQ if you hit this."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
         # [bsz, nh, t, hd]

-        is_causal = past_key_value is None
         if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)

+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         if use_cache:
             # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
             # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()

-        past_key_value = (key_states, value_states) if use_cache else None

         if compare_pytorch_version("v2.0.0", op="ge"):
             attn_output = F.scaled_dot_product_attention(
                 query_states,
                 key_states,
                 value_states,
-                attn_mask=None if is_causal else attention_mask,
-                is_causal=is_causal
+                attn_mask=attention_mask,
+                is_causal=attention_mask is None and q_len > 1
             )
             attn_weights = None
         else:
@@ -187,8 +192,12 @@ def inject_to_model(
             qkv_layer.scales = scales
             qkv_layer.g_idx = g_idx
             qkv_layer.bias = bias

-            attn = cls(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb)

+            # Introduced in Transformers 4.36
+            layer_idx = None
+            if hasattr(m, "layer_idx"):
+                layer_idx = m.layer_idx
+            attn = cls(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb, layer_idx=layer_idx)

             if '.' in name:
                 parent_name = name.rsplit('.', 1)[0]
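Note (illustrative, not part of the commit): the F.scaled_dot_product_attention call in the forward hunk above now passes the attention mask through and only sets is_causal when there is no mask and q_len > 1. During cached decoding the query length is 1, and PyTorch's built-in causal mask is aligned to the top-left corner, so leaving is_causal=True there would let the single query attend only to the first cached position. A minimal sketch of the two cases, with made-up shapes:

import torch
import torch.nn.functional as F

bsz, n_heads, head_dim = 1, 8, 64

# Prefill: several query positions and no explicit mask -> causal masking via is_causal=True.
q = torch.randn(bsz, n_heads, 5, head_dim)
k = v = torch.randn(bsz, n_heads, 5, head_dim)
prefill_out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# Decoding with a populated cache: one query attends to all 6 cached positions, so is_causal must be False.
q = torch.randn(bsz, n_heads, 1, head_dim)
k = v = torch.randn(bsz, n_heads, 6, head_dim)
decode_out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=False)

print(prefill_out.shape, decode_out.shape)  # torch.Size([1, 8, 5, 64]) torch.Size([1, 8, 1, 64])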
