Skip to content

Commit

Permalink
Brgemm register tiling support for bf16 type (#1005)
Browse files Browse the repository at this point in the history
This PR extends the `brgemm register tiling` pass to support `bf16`
type. The changes:
1) Template the existing pass to execute on `linalg.batch_reduce_matmul`
for `fp32` and `linal.generic` for `vnni` opt bf16,
2) Test-cases for `bf16` type.
  • Loading branch information
arun-thmn authored Feb 19, 2025
1 parent cb1e22f commit f8d8a16
Show file tree
Hide file tree
Showing 7 changed files with 351 additions and 333 deletions.
20 changes: 10 additions & 10 deletions benchmarks/config/base/base.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,21 @@
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
"environment": {},
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=8,32 '" ],
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=8,32,1 '" ],
"extensions": ["avx512.*"]
},
"gemm_fp32_mlir_vector_avx2": {
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
"environment": {},
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=4,16 '" ],
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=4,16,1 '" ],
"extensions": ["avx2"]
},
"gemm_fp32_mlir_vector_sve": {
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
"environment": {},
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=4,32 '" ],
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=4,32,1 '" ],
"extensions": ["asimd"]
},
"gemm_bf16_dp2_mlir": {
Expand Down Expand Up @@ -82,21 +82,21 @@
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
"environment": {},
"flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=8,32 '" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=8,32,1 '" ],
"extensions": ["avx512.*"]
},
"mlp_fp32_mlir_vector_avx2": {
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
"environment": {},
"flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=4,16 '" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=4,16,1 '" ],
"extensions": ["avx2" ]
},
"mlp_fp32_mlir_vector_sve": {
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
"environment": {},
"flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=4,32 '" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=4,32,1 '" ],
"extensions": ["asimd"]
},
"mlp_bf16_dp2_mlir": {
Expand Down Expand Up @@ -127,7 +127,7 @@
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024" ],
"environment": {},
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=8,32 '" ],
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=8,32,1 '" ],
"extensions": [ "avx512.*" ]
},
"fp32_3x1024_args_mlir": {
Expand All @@ -141,7 +141,7 @@
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=args --float-type=f32 --batch=256 --layers=1024,1024,1024,1024" ],
"environment": {},
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=8,32 '" ],
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=8,32,1 '" ],
"extensions": [ "avx512.*" ]
},
"bf16_3x1024_const_mlir": {
Expand Down Expand Up @@ -172,7 +172,7 @@
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024" ],
"environment": {},
"flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=8,32 '" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=8,32,1 '" ],
"extensions": [ "avx512.*" ]
},
"fp32_3x1024_args_mlir": {
Expand All @@ -186,7 +186,7 @@
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=args --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024" ],
"environment": {},
"flags": [ "-n", "100", "-run-args=' --def-parallel --vector-to-kernels --registerBlocking=8,32 '" ],
"flags": [ "-n", "100", "-run-args=' --def-parallel --vector-to-kernels --registerBlocking=8,32,1 '" ],
"extensions": [ "avx512.*" ]
},
"bf16_3x1024_const_mlir": {
Expand Down
Loading

0 comments on commit f8d8a16

Please sign in to comment.