
blas_shape: GPU_GEMM: Batch dimension is not collapsible #3439

Open

shivadbhavsar opened this issue Sep 12, 2024 · 3 comments
Labels: bug (Something isn't working), Torch Benchmarks

@shivadbhavsar (Contributor)

Error seen in the HuggingFace torch benchmark OPTForCasualLM.
It only occurs after #3104.

The uncompiled model .mxr can be found on the NAS at:
migraphx/models/torch_benchmarks/OPTForCasualLM.mxr

Repro:
migraphx-driver compile OPTForCasualLM.mxr
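
The per-pass trace shown later in this thread can be reproduced by tracing the compile; a sketch, assuming a build that honors the MIGRAPHX_TRACE_COMPILE environment variable:

MIGRAPHX_TRACE_COMPILE=1 migraphx-driver compile OPTForCasualLM.mxr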

@shivadbhavsar shivadbhavsar added bug Something isn't working Torch Benchmarks labels Sep 12, 2024
@shivadbhavsar shivadbhavsar self-assigned this Sep 18, 2024
@shivadbhavsar (Contributor, Author) commented Sep 25, 2024

Small repro:

import migraphx
import numpy as np

p = migraphx.program()
mm = p.get_main_module()

# Build the transpose -> contiguous -> reshape chain feeding the dot
s1 = migraphx.shape(lens=[4096, 768], type="float_type")
in1 = mm.add_parameter("x", s1)
in1 = mm.add_instruction(migraphx.op("reshape", dims=[2, 2048, 768]), [in1])
in1 = mm.add_instruction(migraphx.op("reshape", dims=[2, -1, 12, 64]), [in1])
in1 = mm.add_instruction(migraphx.op("transpose", permutation=[0, 2, 1, 3]), [in1])
in1 = mm.add_instruction(migraphx.op("contiguous"), [in1])
in1 = mm.add_instruction(migraphx.op("reshape", dims=[24, -1, 64]), [in1])

s2 = migraphx.shape(lens=[2, 12, 2048, 2048], type="float_type")
in2 = mm.add_parameter("x2", s2)

# Elementwise max against the float16 lowest value, broadcast to the full shape
min_lit = mm.add_literal(np.array(-65504, dtype=np.float32))
min_lit = mm.add_instruction(migraphx.op("multibroadcast", out_lens=[2, 12, 2048, 2048]), [min_lit])

max = mm.add_instruction(migraphx.op("max"), [in2, min_lit])
rsp_max = mm.add_instruction(migraphx.op("reshape", dims=[24, 2048, 2048]), [max])
smax = mm.add_instruction(migraphx.op("softmax", axis=-1), [rsp_max])
dot = mm.add_instruction(migraphx.op("dot"), [smax, in1])
dot_rsp = mm.add_instruction(migraphx.op("reshape", dims=[2, 12, 2048, 64]), [dot])

Compile trace attached: gmm_err_trace.txt

@shivadbhavsar (Contributor, Author)

Here's where the issue starts:

Pass: fuse_reduce
Pass: dead_code_elimination
x2 = @param:x2 -> float_type, {2, 12, 2048, 2048}, {50331648, 4194304, 2048, 1}
x = @param:x -> float_type, {4096, 768}, {768, 1}
@2 = reshape[dims={2, 2048, 12, 64}](x) -> float_type, {2, 2048, 12, 64}, {1572864, 768, 64, 1}
@3 = transpose[permutation={0, 2, 1, 3}](@2) -> float_type, {2, 12, 2048, 64}, {1572864, 64, 768, 1}
@4 = reshape[dims={24, 2048, 64}](@3) -> float_type, {24, 2048, 64}, {131072, 64, 1}
@5 = pointwise(x2), [main:pointwise0] -> float_type, {2, 12, 2048, 2048}, {50331648, 4194304, 2048, 1}
@6 = reshape[dims={24, 2048, 2048}](@5) -> float_type, {24, 2048, 2048}, {4194304, 2048, 1}
@7 = fused_reduce[axes={2}](@6), [main:reduce_sum1:main:pointwise3:main:reduce_max0:main:pointwise1] -> float_type, {24, 2048, 2048}, {4194304, 2048, 1}
@8 = dot(@7,@4) -> float_type, {24, 2048, 64}, {131072, 64, 1}
@9 = reshape[dims={2, 12, 2048, 64}](@8) -> float_type, {2, 12, 2048, 64}, {1572864, 131072, 64, 1}

Pass: rewrite_reshapes
Pass: simplify_reshapes
x2 = @param:x2 -> float_type, {2, 12, 2048, 2048}, {50331648, 4194304, 2048, 1}
x = @param:x -> float_type, {4096, 768}, {768, 1}
@2 = reshape[dims={2, 2048, 12, 64}](x) -> float_type, {2, 2048, 12, 64}, {1572864, 768, 64, 1}
@3 = transpose[permutation={0, 2, 1, 3}](@2) -> float_type, {2, 12, 2048, 64}, {1572864, 64, 768, 1}
@4 = pointwise(x2), [main:pointwise0] -> float_type, {2, 12, 2048, 2048}, {50331648, 4194304, 2048, 1}
@5 = fused_reduce[axes={3}](@4), [main:reduce_sum1:main:pointwise3:main:reduce_max0:main:pointwise1_reshape] -> float_type, {2, 12, 2048, 2048}, {50331648, 4194304, 2048, 1}
@6 = dot(@5,@3) -> float_type, {2, 12, 2048, 64}, {1572864, 131072, 64, 1}
@7 = identity(@6) -> float_type, {2, 12, 2048, 64}, {1572864, 131072, 64, 1}

After rewrite_reshapes does this simplification, a contiguous needs to be inserted. Or should there already have been a contiguous op between the transpose and the reshape (@3 and @4 in the fuse_reduce output above)?
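
A quick NumPy analogy (an illustration of the layout problem, not MIGraphX code) for why @3's transposed output blocks collapsing the batch dimensions in place:

import numpy as np

# Same shapes as the trace: x reshaped to (2, 2048, 12, 64), then transposed
a = np.zeros((2, 2048, 12, 64), dtype=np.float32)
t = a.transpose(0, 2, 1, 3)  # shape (2, 12, 2048, 64)
# Byte strides; in elements this is (1572864, 64, 768, 1), matching @3 above
print(t.strides)

# Collapsing (2, 12) -> 24 cannot be expressed over this layout, so a copy
# (the missing contiguous) is required:
t2 = t.reshape(24, 2048, 64)
print(np.shares_memory(a, t2))  # False: the reshape had to materialize a copy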

@shivadbhavsar (Contributor, Author)

Should be fixed by #3428.
