[AArch64] Fix failure with inline asm and svcount #112537

Merged

Conversation

sdesmalen-arm
Collaborator

This fixes an issue where the compiler runs into an assertion failure for the following example:

register svcount_t pred asm("pn8") = svptrue_c8();
asm("ld1w { z0.s, z4.s, z8.s, z12.s }, %[pred]/z, [x0]\n"
:
: [pred] "Uph" (pred)
: "memory", "cc");

Here the register constraint that ends up in the LLVM IR is "{pn8}". The code in TargetLowering::getRegForInlineAsmConstraint that parses that string follows a path where it looks up a suitable register class for this register (the PPRorPNR register class) and then chooses nxv16i1 as a suitable type for that class. Each choice is correct on its own, but the combined result isn't, because the type should be aarch64svcount.
This mismatch then causes problems later on in SelectionDAGBuilder.cpp, in CopyToReg, because the type of the actual value and the type computed from the constraint don't match.

This PR pre-empts this issue by parsing the predicate explicitly and returning the correct register class.
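
As a rough illustration of the explicit parsing described above, here is a standalone sketch that follows the same checks as `parsePredicateRegAsConstraint` in the diff, but without any LLVM headers. `PredClass`, `PredReg`, and `parsePredicateReg` are hypothetical names invented for this sketch; the real patch returns `AArch64::PNRRegClass` or `AArch64::PPRRegClass` together with a physical register number.

```cpp
#include <charconv>
#include <optional>
#include <string_view>
#include <system_error>

// Hypothetical stand-ins for the two register classes the patch returns;
// these are not LLVM types.
enum class PredClass { PPR, PNR };

struct PredReg {
  PredClass Class; // PNR for "{pnN}", PPR for "{pN}"
  unsigned Index;  // the N in "{pN}" / "{pnN}"
};

// Follows the checks in the patch: braces are required, then a leading 'p',
// an optional 'n' (predicate-as-counter), then a decimal register index.
std::optional<PredReg> parsePredicateReg(std::string_view C) {
  if (C.size() < 3 || C.front() != '{' || C.back() != '}' || C[1] != 'p')
    return std::nullopt;

  C = C.substr(2, C.size() - 3); // strip the leading "{p" and trailing "}"
  bool IsCount = !C.empty() && C.front() == 'n';
  if (IsCount)
    C.remove_prefix(1);

  unsigned V = 0;
  auto [End, Ec] = std::from_chars(C.data(), C.data() + C.size(), V);
  // Reject empty or partially numeric text, and use the same upper bound
  // (V > 31) that the patch uses.
  if (Ec != std::errc() || End != C.data() + C.size() || V > 31)
    return std::nullopt;

  return PredReg{IsCount ? PredClass::PNR : PredClass::PPR, V};
}
```

With this sketch, `parsePredicateReg("{pn8}")` yields the PNR class with index 8, while a non-predicate constraint such as `"{x0}"` yields `std::nullopt` so the caller can fall through to the existing handling.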

@llvmbot
Member

llvmbot commented Oct 16, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Sander de Smalen (sdesmalen-arm)

Full diff: https://github.com/llvm/llvm-project/pull/112537.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+32)
  • (modified) llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll (+64)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index ed06d8a5d63013..c80d1fbfe6a95b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -11804,6 +11804,36 @@ const char *AArch64TargetLowering::LowerXConstraint(EVT ConstraintVT) const {
 
 enum class PredicateConstraint { Uph, Upl, Upa };
 
+// Returns a {Reg, RegisterClass} tuple if the constraint is
+// a specific predicate register.
+//
+// For some constraint like "{pn3}" the default path in
+// TargetLowering::getRegForInlineAsmConstraint() leads it to determine that a
+// suitable register class for this register is "PPRorPNR", after which it
+// determines that nxv16i1 is an appropriate type for the constraint, which is
+// not what we want. The code here pre-empts this by matching the register
+// explicitly.
+static std::optional<std::pair<unsigned, const TargetRegisterClass *>>
+parsePredicateRegAsConstraint(StringRef Constraint) {
+  if (!Constraint.starts_with('{') || !Constraint.ends_with('}') ||
+      Constraint[1] != 'p')
+    return std::nullopt;
+
+  Constraint = Constraint.substr(2, Constraint.size() - 3);
+  bool IsPredicateAsCount = Constraint.starts_with("n");
+  if (IsPredicateAsCount)
+    Constraint = Constraint.drop_front(1);
+
+  unsigned V;
+  if (Constraint.getAsInteger(10, V) || V > 31)
+    return std::nullopt;
+
+  if (IsPredicateAsCount)
+    return std::make_pair(AArch64::PN0 + V, &AArch64::PNRRegClass);
+  else
+    return std::make_pair(AArch64::P0 + V, &AArch64::PPRRegClass);
+}
+
 static std::optional<PredicateConstraint>
 parsePredicateConstraint(StringRef Constraint) {
   return StringSwitch<std::optional<PredicateConstraint>>(Constraint)
@@ -12051,6 +12081,8 @@ AArch64TargetLowering::getRegForInlineAsmConstraint(
       break;
     }
   } else {
+    if (const auto P = parsePredicateRegAsConstraint(Constraint))
+      return *P;
     if (const auto PC = parsePredicateConstraint(Constraint))
       if (const auto *RegClass = getPredicateRegisterClass(*PC, VT))
         return std::make_pair(0U, RegClass);
diff --git a/llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll b/llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll
index 068e194779c153..8c6ae87b092cba 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll
@@ -119,3 +119,67 @@ define <vscale x 8 x half> @test_svfadd_f16_Uph_constraint(<vscale x 16 x i1> %P
   %1 = tail call <vscale x 8 x half> asm "fadd $0.h, $1/m, $2.h, $3.h", "=w,@3Uph,w,w"(<vscale x 16 x i1> %Pg, <vscale x 8 x half> %Zn, <vscale x 8 x half> %Zm)
   ret <vscale x 8 x half> %1
 }
+
+define void @explicit_p0(ptr %p) {
+  ; CHECK-LABEL: name: explicit_p0
+  ; CHECK: bb.0 (%ir-block.0):
+  ; CHECK-NEXT:   liveins: $x0
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT:   [[COPY:%[0-9]+]]:gpr64 = COPY $x0
+  ; CHECK-NEXT:   [[PTRUE_B:%[0-9]+]]:ppr = PTRUE_B 31, implicit $vg
+  ; CHECK-NEXT:   $p0 = COPY [[PTRUE_B]]
+  ; CHECK-NEXT:   [[COPY1:%[0-9]+]]:gpr64common = COPY [[COPY]]
+  ; CHECK-NEXT:   INLINEASM &"ld4w { z0.s, z1.s, z2.s, z3.s }, $1/z, [$0]", 1 /* sideeffect attdialect */, 2818058 /* regdef:GPR64common */, def %1, 9 /* reguse */, $p0, 2147483657 /* reguse tiedto:$0 */, [[COPY1]](tied-def 3)
+  ; CHECK-NEXT:   RET_ReallyLR
+  %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.b8(i32 31)
+  %2 = tail call i64 asm sideeffect "ld4w { z0.s, z1.s, z2.s, z3.s }, $1/z, [$0]", "=r,{p0},0"(<vscale x 16 x i1> %1, ptr %p)
+  ret void
+}
+
+define void @explicit_p8_invalid(ptr %p) {
+  ; CHECK-LABEL: name: explicit_p8_invalid
+  ; CHECK: bb.0 (%ir-block.0):
+  ; CHECK-NEXT:   liveins: $x0
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT:   [[COPY:%[0-9]+]]:gpr64 = COPY $x0
+  ; CHECK-NEXT:   [[PTRUE_B:%[0-9]+]]:ppr = PTRUE_B 31, implicit $vg
+  ; CHECK-NEXT:   $p8 = COPY [[PTRUE_B]]
+  ; CHECK-NEXT:   [[COPY1:%[0-9]+]]:gpr64common = COPY [[COPY]]
+  ; CHECK-NEXT:   INLINEASM &"ld4w { z0.s, z1.s, z2.s, z3.s }, $1/z, [$0]", 1 /* sideeffect attdialect */, 2818058 /* regdef:GPR64common */, def %1, 9 /* reguse */, $p8, 2147483657 /* reguse tiedto:$0 */, [[COPY1]](tied-def 3)
+  ; CHECK-NEXT:   RET_ReallyLR
+  %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.b8(i32 31)
+  %2 = tail call i64 asm sideeffect "ld4w { z0.s, z1.s, z2.s, z3.s }, $1/z, [$0]", "=r,{p8},0"(<vscale x 16 x i1> %1, ptr %p)
+  ret void
+}
+
+define void @explicit_pn8(ptr %p) {
+  ; CHECK-LABEL: name: explicit_pn8
+  ; CHECK: bb.0 (%ir-block.0):
+  ; CHECK-NEXT:   liveins: $x0
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT:   [[COPY:%[0-9]+]]:gpr64 = COPY $x0
+  ; CHECK-NEXT:   [[PTRUE_C_B:%[0-9]+]]:pnr_p8to15 = PTRUE_C_B implicit $vg
+  ; CHECK-NEXT:   $pn8 = COPY [[PTRUE_C_B]]
+  ; CHECK-NEXT:   [[COPY1:%[0-9]+]]:gpr64common = COPY [[COPY]]
+  ; CHECK-NEXT:   INLINEASM &"ld1w { z0.s, z4.s, z8.s, z12.s }, $1/z, [$0]", 1 /* sideeffect attdialect */, 2818058 /* regdef:GPR64common */, def %1, 9 /* reguse */, $pn8, 2147483657 /* reguse tiedto:$0 */, [[COPY1]](tied-def 3)
+  ; CHECK-NEXT:   RET_ReallyLR
+  %1 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
+  %2 = tail call i64 asm sideeffect "ld1w { z0.s, z4.s, z8.s, z12.s }, $1/z, [$0]", "=r,{pn8},0"(target("aarch64.svcount") %1, ptr %p)
+  ret void
+}
+
+define void @explicit_pn0_invalid(ptr %p) {
+  ; CHECK-LABEL: name: explicit_pn0_invalid
+  ; CHECK: bb.0 (%ir-block.0):
+  ; CHECK-NEXT:   liveins: $x0
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT:   [[COPY:%[0-9]+]]:gpr64 = COPY $x0
+  ; CHECK-NEXT:   [[PTRUE_C_B:%[0-9]+]]:pnr_p8to15 = PTRUE_C_B implicit $vg
+  ; CHECK-NEXT:   $pn0 = COPY [[PTRUE_C_B]]
+  ; CHECK-NEXT:   [[COPY1:%[0-9]+]]:gpr64common = COPY [[COPY]]
+  ; CHECK-NEXT:   INLINEASM &"ld1w { z0.s, z4.s, z8.s, z12.s }, $1/z, [$0]", 1 /* sideeffect attdialect */, 2818058 /* regdef:GPR64common */, def %1, 9 /* reguse */, $pn0, 2147483657 /* reguse tiedto:$0 */, [[COPY1]](tied-def 3)
+  ; CHECK-NEXT:   RET_ReallyLR
+  %1 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
+  %2 = tail call i64 asm sideeffect "ld1w { z0.s, z4.s, z8.s, z12.s }, $1/z, [$0]", "=r,{pn0},0"(target("aarch64.svcount") %1, ptr %p)
+  ret void
+}

Collaborator

@paulwalker-arm paulwalker-arm left a comment

The real problem here is that we shouldn't have a register class that contains mutually exclusive registers. That said, the PR itself looks sound and fixes a bug, so I'm happy to accept.

I'll investigate the register class side of things separately and revisit this code to see if my assertion is correct.
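
To make the concern about a mixed register class concrete, here is a toy model of a lookup that resolves a register name to a class and then takes that class's first type. Everything in it (`ToyRegClass`, `ToyMatch`, `lookupByRegName`, and the contents of the two tables) is an assumption made for illustration; it is not LLVM's data structure or lookup code, and the type lists are simplified.

```cpp
#include <optional>
#include <string>
#include <string_view>
#include <vector>

// Toy model only: names and layout are illustrative assumptions.
struct ToyRegClass {
  std::string Name;
  std::vector<std::string> Regs;  // register names in the class
  std::vector<std::string> Types; // first entry is treated as the default type
};

struct ToyMatch {
  std::string ClassName;
  std::string Type;
};

// Finds the first class containing Reg and reports that class's first type.
std::optional<ToyMatch> lookupByRegName(const std::vector<ToyRegClass> &Classes,
                                        std::string_view Reg) {
  for (const ToyRegClass &RC : Classes)
    for (const std::string &R : RC.Regs)
      if (std::string_view(R) == Reg && !RC.Types.empty())
        return ToyMatch{RC.Name, RC.Types.front()};
  return std::nullopt;
}

// One class holding both kinds of predicate register (assumed contents).
std::vector<ToyRegClass> mixedClasses() {
  return {ToyRegClass{"PPRorPNR", {"p8", "pn8"}, {"nxv16i1", "aarch64svcount"}}};
}

// Separate classes, each with a single matching type (assumed contents).
std::vector<ToyRegClass> splitClasses() {
  return {ToyRegClass{"PPR", {"p8"}, {"nxv16i1"}},
          ToyRegClass{"PNR", {"pn8"}, {"aarch64svcount"}}};
}
```

In this model, looking up `"pn8"` in `mixedClasses()` reports `nxv16i1`, which does not match an `aarch64svcount` operand, whereas the same lookup in `splitClasses()` reports `aarch64svcount`.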

@sdesmalen-arm sdesmalen-arm force-pushed the fix-inline-asm-svcount-pnr-or-ppr branch from 2eba07b to 9da6284 Compare October 24, 2024 16:23
@sdesmalen-arm sdesmalen-arm merged commit db0e376 into llvm:main Oct 24, 2024
5 of 8 checks passed
@frobtech frobtech mentioned this pull request Oct 25, 2024
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024