Skip to content

Commit

Permalink
[ROCm] CK implementation support causal mask (#18943)
Browse files Browse the repository at this point in the history
Use `MaskingSpecialization::MaskOutUpperTriangle` to support causal mask
in ck implementation.
  • Loading branch information
PeixuanZuo authored Feb 2, 2024
1 parent a2eb967 commit 9139bdd
Show file tree
Hide file tree
Showing 7 changed files with 364 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ using MaskingSpecialization = ck::tensor_operation::device::MaskingSpecializatio

// Short alias for CK's PassThrough element-wise operation.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute;              // the interface
using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle;  // the implementation

// Short name for CK's GemmSpecialization::Default.
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
Expand Down Expand Up @@ -141,6 +141,35 @@ std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
F16, ck::Tuple<F16, F16>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>();

// fp16, non-biased, non-masked (empty D0 tuple), causal: uses MaskingSpecialization::MaskOutUpperTriangle
// Declaration only; the definition is expected in a separate translation unit.
template <>
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
2, 1, 1, 1, 1,
F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>,
PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>
GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>();

// fp16, biased, non-masked (one fp16 D0 tensor), causal: uses MaskingSpecialization::MaskOutUpperTriangle
// Declaration only; the definition is expected in a separate translation unit.
template <>
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
2, 1, 1, 1, 1,
F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>,
PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>
GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
F16, ck::Tuple<F16>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>();

// fp16, biased, fp16 masked, basically, two bias (two fp16 D0 tensors),
// causal: uses MaskingSpecialization::MaskOutUpperTriangle
// Declaration only; the definition is expected in a separate translation unit.
template <>
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
2, 1, 1, 1, 1,
F16, F16, F16, F16, ck::Tuple<F16, F16>, ck::Tuple<>,
PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>
GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
F16, ck::Tuple<F16, F16>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>();

} // namespace internal
} // namespace rocm
} // namespace contrib
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
return instances;
}

// Device-op interface for the fp16 variant with an empty D0 tuple (no bias, no mask tensor)
// that uses the MaskOutUpperTriangle masking specialization.
using NonBiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute<
    2, 1, 1, 1, 1,
    F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>,
    PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough,
    MaskingSpecialization::MaskOutUpperTriangle>;

// Builds and returns the list of CK kernel instances for the configuration described above.
template <>
std::vector<std::unique_ptr<NonBiasedNonmaskedCausal>>
GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
    F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() {
  // Compile-time collection of kernel configurations to register.
  using KernelConfigs = device_batched_gemm_softmax_gemm_permute_instances<
      2, 1, 1, 1, 1,
      F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp,
      MaskingSpecialization::MaskOutUpperTriangle>;

  std::vector<std::unique_ptr<NonBiasedNonmaskedCausal>> causal_ops;
  ck::tensor_operation::device::instance::add_device_operation_instances(causal_ops, KernelConfigs{});
  return causal_ops;
}

} // namespace internal
} // namespace rocm
} // namespace contrib
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
return instances;
}

// Device-op interface for the fp16 variant with a single fp16 D0 tensor (the bias)
// that uses the MaskOutUpperTriangle masking specialization.
using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute<
    2, 1, 1, 1, 1,
    F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>,
    PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough,
    MaskingSpecialization::MaskOutUpperTriangle>;

// Builds and returns the list of CK kernel instances for the configuration described above.
template <>
std::vector<std::unique_ptr<BiasedNonmaskedCausal>>
GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
    F16, ck::Tuple<F16>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() {
  // Compile-time collection of kernel configurations to register.
  using KernelConfigs = device_batched_gemm_softmax_gemm_permute_instances<
      2, 1, 1, 1, 1,
      F16, ck::Tuple<F16>, F32, PreSoftmaxAttentionScoreOp,
      MaskingSpecialization::MaskOutUpperTriangle>;

  std::vector<std::unique_ptr<BiasedNonmaskedCausal>> causal_ops;
  ck::tensor_operation::device::instance::add_device_operation_instances(causal_ops, KernelConfigs{});
  return causal_ops;
}

} // namespace internal
} // namespace rocm
} // namespace contrib
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
return instances;
}

// Device-op interface for the fp16 variant with two fp16 D0 tensors
// that uses the MaskOutUpperTriangle masking specialization.
// NOTE(review): the alias says "Nonmasked", yet the D0 tuple holds two F16 entries (presumably
// bias + converted mask); the name looks copied from the single-bias variant - confirm and rename.
using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute<
    2, 1, 1, 1, 1,
    F16, F16, F16, F16, ck::Tuple<F16, F16>, ck::Tuple<>,
    PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough,
    MaskingSpecialization::MaskOutUpperTriangle>;

// Builds and returns the list of CK kernel instances for the configuration described above.
template <>
std::vector<std::unique_ptr<BiasedNonmaskedCausal>>
GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
    F16, ck::Tuple<F16, F16>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() {
  // Compile-time collection of kernel configurations to register.
  using KernelConfigs = device_batched_gemm_softmax_gemm_permute_instances<
      2, 1, 1, 1, 1,
      F16, ck::Tuple<F16, F16>, F32, PreSoftmaxAttentionScoreOp,
      MaskingSpecialization::MaskOutUpperTriangle>;

  std::vector<std::unique_ptr<BiasedNonmaskedCausal>> causal_ops;
  ck::tensor_operation::device::instance::add_device_operation_instances(causal_ops, KernelConfigs{});
  return causal_ops;
}

} // namespace internal
} // namespace rocm
} // namespace contrib
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -732,122 +732,154 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp<GemmSoftmaxGem
#ifdef USE_COMPOSABLE_KERNEL
template <typename T, bool USE_BIAS, bool USE_MASK>
auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
template <typename U, typename V, typename T, bool USE_BIAS, bool USE_MASK>
auto GetArgAndRunInvoker(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams<T>* params) {
constexpr const int kNumBiasBuffer = static_cast<int>(USE_BIAS) + static_cast<int>(USE_MASK);
using Nop = ck::tensor_operation::element_wise::PassThrough;
using Acc0ElementOp = internal::PreSoftmaxAttentionScoreOp;
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedMode(params->attention),
"attention mode is not supported, got ", params->attention->mode);
if constexpr (USE_BIAS) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->bias_buffer == nullptr, "biased version only support input with bias");
} else {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->bias_buffer != nullptr, "non-biased version only support input without bias");
}
if constexpr (USE_MASK) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedMaskType(params->attention),
"mask type is not supported, got ", params->attention->mask_type);
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->mask_index_buffer == nullptr, "masked version only support input with mask");
} else {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->mask_index_buffer != nullptr, "non-masked version only support input without mask");
}
auto attn = params->attention;
const int& G0 = attn->batch_size;
const int& G1 = attn->num_heads;
const int& M = attn->sequence_length;
const int& N = attn->total_sequence_length;
const int& K = attn->head_size;
const int& O = attn->v_head_size;
{
auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch();
ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch");
}
auto [qs, ks, vs] = GetQkvStrides(attn);
std::vector<ck::index_t> q_buffer_lengths = {G0, G1, M, K};
std::vector<ck::index_t> q_buffer_strides = qs.template ForBNSHCoord<std::vector<ck::index_t>>();
std::vector<ck::index_t> k_buffer_lengths = {G0, G1, N, K};
std::vector<ck::index_t> k_buffer_strides = ks.template ForBNSHCoord<std::vector<ck::index_t>>();
std::vector<ck::index_t> v_buffer_lengths = {G0, G1, O, N};
std::vector<ck::index_t> v_buffer_strides = vs.template ForBNHSCoord<std::vector<ck::index_t>>();
std::vector<ck::index_t> out_buffer_lengths = {G0, G1, M, O};
std::vector<ck::index_t> out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213
std::array<void*, kNumBiasBuffer> bias_buffers{};
std::array<std::vector<ck::index_t>, kNumBiasBuffer> bias_lengths{};
std::array<std::vector<ck::index_t>, kNumBiasBuffer> bias_strides{};
if constexpr (USE_BIAS) {
bias_buffers[0] = const_cast<T*>(params->bias_buffer);
bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N)
bias_strides[0] = {G1 * M * N, M * N, N, 1};
}
if constexpr (USE_MASK) {
bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer;
bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N)
if (params->mask_index_dims.size() == 2) { // [B,T]
bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1};
} else if (params->mask_index_dims.size() == 3) { // [B,S,T]
bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1};
} else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T]
bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1};
} else {
ORT_ENFORCE(false, "Unreachable");
}
}
auto arg = impl->MakeArgumentPointer(
params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer,
bias_buffers, // Gemm1 bias, as attention mask
{}, // Gemm2 bias
q_buffer_lengths, q_buffer_strides,
k_buffer_lengths, k_buffer_strides,
v_buffer_lengths, v_buffer_strides,
out_buffer_lengths, out_buffer_strides,
bias_lengths, bias_strides,
{},
{},
Nop{},
Nop{},
Acc0ElementOp{params->scale},
Nop{},
Nop{});
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
impl->GetTypeString(), " does not support the params");
if constexpr (USE_MASK) {
ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp<T>::LaunchConvertToFilledMaskValue(params));
}
invoker->Run(arg.get(), StreamConfig{params->StreamHandle()});
return Status::OK();
}
template <typename T, bool USE_BIAS, bool USE_MASK>
auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
using CKDataType = typename CKDataTypeAdaptor<T>::type;
using D0DataType = typename ck::detail::tuple_concat<
std::conditional_t<USE_BIAS, ck::Tuple<CKDataType>, ck::Tuple<>>,
std::conditional_t<USE_MASK, ck::Tuple<CKDataType>, ck::Tuple<>>>::type;
constexpr static auto MaskingSpec =
constexpr static auto MaskingSpecMaskDisabled =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
constexpr static auto MaskingSpecMaskOutUpperTriangle =
ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle;
std::vector<std::pair<std::string, Op<GemmSoftmaxGemmPermuteParams<T>>>>
ret;
std::vector<std::pair<std::string, Op<GemmSoftmaxGemmPermuteParams<T>>>> ret;
for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpec>()) {
CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskDisabled>()) {
auto type_string = impl->GetTypeString();
auto invoker = impl->MakeInvokerPointer();
auto op = [impl = std::move(impl), invoker = std::move(invoker)](
const GemmSoftmaxGemmPermuteParams<T>* params) -> Status {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedMode(params->attention),
"attention mode is not supported, got ", params->attention->mode);
if constexpr (USE_BIAS) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->bias_buffer == nullptr, "biased version only support input with bias");
} else {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->bias_buffer != nullptr, "non-biased version only support input without bias");
}
if constexpr (USE_MASK) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedMaskType(params->attention),
"mask type is not supported, got ", params->attention->mask_type);
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->mask_index_buffer == nullptr, "masked version only support input with mask");
} else {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->mask_index_buffer != nullptr, "non-masked version only support input without mask");
}
params->attention->is_unidirectional, "unidirectional attention is not supported with MaskingSpecMaskDisabled");
auto attn = params->attention;
const int& G0 = attn->batch_size;
const int& G1 = attn->num_heads;
const int& M = attn->sequence_length;
const int& N = attn->total_sequence_length;
const int& K = attn->head_size;
const int& O = attn->v_head_size;
{
auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch();
ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch");
}
return GetArgAndRunInvoker<decltype(impl), decltype(invoker), T, USE_BIAS, USE_MASK>(impl, invoker, params);
};
ret.emplace_back(std::make_pair(std::move(type_string), std::move(op)));
}
auto [qs, ks, vs] = GetQkvStrides(attn);
std::vector<ck::index_t> q_buffer_lengths = {G0, G1, M, K};
std::vector<ck::index_t> q_buffer_strides = qs.template ForBNSHCoord<std::vector<ck::index_t>>();
std::vector<ck::index_t> k_buffer_lengths = {G0, G1, N, K};
std::vector<ck::index_t> k_buffer_strides = ks.template ForBNSHCoord<std::vector<ck::index_t>>();
std::vector<ck::index_t> v_buffer_lengths = {G0, G1, O, N};
std::vector<ck::index_t> v_buffer_strides = vs.template ForBNHSCoord<std::vector<ck::index_t>>();
std::vector<ck::index_t> out_buffer_lengths = {G0, G1, M, O};
std::vector<ck::index_t> out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213
std::array<void*, kNumBiasBuffer> bias_buffers{};
std::array<std::vector<ck::index_t>, kNumBiasBuffer> bias_lengths{};
std::array<std::vector<ck::index_t>, kNumBiasBuffer> bias_strides{};
if constexpr (USE_BIAS) {
bias_buffers[0] = const_cast<T*>(params->bias_buffer);
bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N)
bias_strides[0] = {G1 * M * N, M * N, N, 1};
}
if constexpr (USE_MASK) {
bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer;
bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N)
if (params->mask_index_dims.size() == 2) { // [B,T]
bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1};
} else if (params->mask_index_dims.size() == 3) { // [B,S,T]
bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1};
} else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T]
bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1};
} else {
ORT_ENFORCE(false, "Unreachable");
}
}
for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskOutUpperTriangle>()) {
auto type_string = impl->GetTypeString();
auto arg = impl->MakeArgumentPointer(
params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer,
bias_buffers, // Gemm1 bias, as attention mask
{}, // Gemm2 bias
q_buffer_lengths, q_buffer_strides,
k_buffer_lengths, k_buffer_strides,
v_buffer_lengths, v_buffer_strides,
out_buffer_lengths, out_buffer_strides,
bias_lengths, bias_strides,
{},
{},
Nop{},
Nop{},
Acc0ElementOp{params->scale},
Nop{},
Nop{});
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
impl->GetTypeString(), " does not support the params");
if constexpr (USE_MASK) {
ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp<T>::LaunchConvertToFilledMaskValue(params));
}
invoker->Run(arg.get(), StreamConfig{params->StreamHandle()});
return Status::OK();
auto invoker = impl->MakeInvokerPointer();
auto op = [impl = std::move(impl), invoker = std::move(invoker)](
const GemmSoftmaxGemmPermuteParams<T>* params) -> Status {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!params->attention->is_unidirectional, "bidirectional attention is not supported with MaskingSpecMaskOutUpperTriangle");
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->attention->sequence_length != params->attention->total_sequence_length,
"seqence_length != total_seqence_length is not supported with MaskingSpecMaskOutUpperTriangle");
return GetArgAndRunInvoker<decltype(impl), decltype(invoker), T, USE_BIAS, USE_MASK>(impl, invoker, params);
};
ret.emplace_back(std::make_pair(std::move(type_string), std::move(op)));
}
return ret;
}
#endif // USE_COMPOSABLE_KERNEL
Expand Down
Loading

0 comments on commit 9139bdd

Please sign in to comment.