[XPU] llama add xpu support #8282

Merged: 9 commits, Apr 29, 2024
Changes from 1 commit
fix
dynamicheart committed Apr 22, 2024
commit e9a4b871d6127ba7bfbfc92b43c96219a0731b20
4 changes: 3 additions & 1 deletion paddlenlp/transformers/llama/modeling.py
@@ -413,13 +413,13 @@
        if self.config.use_fused_rms_norm:
            if get_env_device() == "npu":
                return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
            elif get_env_device() == "xpu":
                try:
                    import paddle_xpu_nn  # noqa: F821

                    return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
                except ImportError:
                    pass
            return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)

        if paddle.in_dynamic_mode():
@@ -589,47 +589,47 @@

                ColumnParallelLinear = MC2ColumnSeqParallelLinear
                RowParallelLinear = MC2RowSeqParallelLinear
            elif get_env_device() == "xpu":
                try:
                    from paddle_xpu.layers.nn.sequence_parallel import (  # noqa: F401
                        XPUColumnSequenceParallelLinear,
                        XPURowSequenceParallelLinear,
                    )

                    ColumnParallelLinear = XPUColumnSequenceParallelLinear
                    RowParallelLinear = XPURowSequenceParallelLinear
                except ImportError:
                    ColumnParallelLinear = ColumnSequenceParallelLinear
                    RowParallelLinear = RowSequenceParallelLinear
            else:
                ColumnParallelLinear = ColumnSequenceParallelLinear
                RowParallelLinear = RowSequenceParallelLinear
        else:
            if get_env_device() == "xpu":
                try:
                    from paddle_xpu.layers.nn import (  # noqa: F401
                        ColumnParallelLinear as XPUColumnParallelLinear,
                    )
                    from paddle_xpu.layers.nn import (  # noqa: F401
                        RowParallelLinear as XPURowParallelLinear,
                    )

                    ColumnParallelLinear = XPUColumnParallelLinear
                    RowParallelLinear = XPURowParallelLinear
                except ImportError:
                    ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
                    RowParallelLinear = fleet.meta_parallel.RowParallelLinear
            else:
                ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
                RowParallelLinear = fleet.meta_parallel.RowParallelLinear

        if get_env_device() == "xpu":
            try:
                from paddle_xpu.layers.nn import Linear as XPULinear  # noqa: F401

                Linear = XPULinear
            except ImportError:
                Linear = nn.Linear
        else:
            Linear = nn.Linear

@@ -663,7 +663,7 @@
                )
        else:
            if config.fuse_attention_ffn:
                self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
            else:
                self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
                self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
@@ -673,16 +673,18 @@
    def forward(self, x):
        if self.fuse_attention_ffn:
            # FIXME(yangjianbang): use paddle's native swiglu
            if get_env_device() == "xpu":
                try:
                    import paddle_xpu_nn  # noqa: F821

                    out = self.gate_up_fused_proj(x)
                    out = paddle_xpu_nn.xpu_swiglu(out, axis=-1, turn=True)
                    out = self.down_proj(out)
                    return out
                except ImportError:
                    pass
                gate_out, up_out = paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1)
                out = self.down_proj(F.silu(gate_out) * up_out)
                return out

            x = swiglu(self.gate_up_fused_proj(x))
        else:
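For readers unfamiliar with the fused feed-forward layout, the following standalone sketch (not part of the PR; the layer sizes and the use of paddle.nn.Linear stand-ins are illustrative assumptions) spells out what the ImportError fallback in the hunk above computes: the fused projection concatenates the gate and up projections along the last axis, and SwiGLU is silu(gate) multiplied elementwise by up before the down projection.

```python
import paddle
import paddle.nn.functional as F

# Stand-ins for the real layers: hidden_size=4, intermediate_size=8, so the fused
# projection output has width 2 * intermediate_size and holds [gate | up] on the last axis.
gate_up_fused_proj = paddle.nn.Linear(4, 2 * 8, bias_attr=False)
down_proj = paddle.nn.Linear(8, 4, bias_attr=False)

x = paddle.randn([2, 4])

# Fallback path from the diff: split the fused projection and apply SwiGLU manually.
gate_out, up_out = paddle.chunk(gate_up_fused_proj(x), chunks=2, axis=-1)
out = down_proj(F.silu(gate_out) * up_out)
print(out.shape)  # [2, 4]
```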
@@ -761,47 +763,47 @@

                ColumnParallelLinear = MC2ColumnSeqParallelLinear
                RowParallelLinear = MC2RowSeqParallelLinear
            elif get_env_device() == "xpu":
                try:
                    from paddle_xpu.layers.nn.sequence_parallel import (  # noqa: F401
                        XPUColumnSequenceParallelLinear,
                        XPURowSequenceParallelLinear,
                    )

                    ColumnParallelLinear = XPUColumnSequenceParallelLinear
                    RowParallelLinear = XPURowSequenceParallelLinear
                except ImportError:
                    ColumnParallelLinear = ColumnSequenceParallelLinear
                    RowParallelLinear = RowSequenceParallelLinear
            else:
                ColumnParallelLinear = ColumnSequenceParallelLinear
                RowParallelLinear = RowSequenceParallelLinear
        else:
            if get_env_device() == "xpu":
                try:
                    from paddle_xpu.layers.nn import (  # noqa: F401
                        ColumnParallelLinear as XPUColumnParallelLinear,
                    )
                    from paddle_xpu.layers.nn import (  # noqa: F401
                        RowParallelLinear as XPURowParallelLinear,
                    )

                    ColumnParallelLinear = XPUColumnParallelLinear
                    RowParallelLinear = XPURowParallelLinear
                except ImportError:
                    ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
                    RowParallelLinear = fleet.meta_parallel.RowParallelLinear
            else:
                ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
                RowParallelLinear = fleet.meta_parallel.RowParallelLinear

        if get_env_device() == "xpu":
            try:
                from paddle_xpu.layers.nn import Linear as XPULinear  # noqa: F401

                Linear = XPULinear
            except ImportError:
                Linear = nn.Linear
        else:
            Linear = nn.Linear
@@ -834,12 +836,12 @@
                        gather_output=False,
                    )
                else:
                    self.k_proj = Linear(
                        self.hidden_size,
                        self.config.num_key_value_heads * self.head_dim,
                        bias_attr=False,
                    )
                    self.v_proj = Linear(
                        self.hidden_size,
                        self.config.num_key_value_heads * self.head_dim,
                        bias_attr=False,
@@ -847,7 +849,7 @@

        else:
            if self.fuse_attention_qkv:
                self.qkv_proj = Linear(
                    self.hidden_size,
                    self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                    bias_attr=False,
@@ -1522,10 +1524,10 @@
            expanded_attn_mask = expanded_attn_mask.astype("float16")
            expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
        elif get_env_device() == "xpu":
            x = paddle.to_tensor(0.0, dtype=dtype)
            y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
            expanded_attn_mask = expanded_attn_mask.astype(dtype)
            expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
dynamicheart (Contributor, Author): When x and y are passed in as integer scalars, paddle.where treats them as int64 tensors of shape [1] and performs a broadcast_add; see search.py for details.

Collaborator: This looks almost the same as the npu logic above. Can the two branches be shared?

dynamicheart (Contributor, Author): In principle they could be shared, but the npu branch hardcodes the dtype to float16, while programs running on xpu may use either float16 or bfloat16. Do we need to modify the npu module?

Collaborator: @SylarTiaNII could you take a look?

dynamicheart (Contributor, Author): Following @wuhuachaocoding's suggestion, this is kept as two separate if/elif branches.

        else:
            expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
        return expanded_attn_mask
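To make the dtype discussion in the review thread above concrete, here is a small, hypothetical repro (not part of the PR). The values and shapes are invented, and it assumes a recent Paddle build where paddle.where accepts scalar x/y and paddle.finfo is available. It contrasts passing Python scalars, which the author notes are promoted to int64 tensors of shape [1], with passing tensors already built in the mask dtype, as the xpu branch above does; on XPU that dtype may be float16 or bfloat16, which is why the branch does not hardcode float16 the way the npu branch does.

```python
import paddle

# dtype of the attention mask; on XPU this may be paddle.float16 or paddle.bfloat16.
dtype = paddle.float16
mask = paddle.to_tensor([[True, False], [False, True]])

# Scalar x/y: paddle.where promotes the ints to int64 tensors of shape [1] and
# broadcasts them, so the result does not stay in the intended floating-point dtype.
scalar_mask = paddle.where(mask, 0, -1)
print(scalar_mask.dtype)  # expected int64, per the review comment above

# Tensor x/y carrying the target dtype keep the computation in that dtype,
# mirroring the xpu branch in the hunk above.
x = paddle.to_tensor(0.0, dtype=dtype)
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
tensor_mask = paddle.where(mask, x, y)
print(tensor_mask.dtype)  # paddle.float16
```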
@@ -1807,14 +1809,14 @@
        if self.weight.is_distributed:
            self.weight.split_axis = 1
        if get_env_device() == "xpu":
            try:
                from paddle_xpu.layers.nn import (  # noqa: F401
                    parallel_matmul as xpu_parallel_matmul,
                )

                self.xpu_parallel_matmul = xpu_parallel_matmul()
            except ImportError:
                self.xpu_parallel_matmul = None

    def forward(self, hidden_states, tensor_parallel_output=None):
        if self.config.sequence_parallel:
@@ -1829,7 +1831,7 @@
            tensor_parallel_output = self.config.tensor_parallel_output

        if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None:
            logits = self.xpu_parallel_matmul(
                hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output, training=self.training
            )
        else: