Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DAG] Support saturated truncate #99418

Merged
merged 6 commits into from
Aug 14, 2024
Merged

[DAG] Support saturated truncate #99418

merged 6 commits into from
Aug 14, 2024

Conversation

ParkHanbum
Copy link
Contributor

@ParkHanbum ParkHanbum commented Jul 18, 2024

A truncate is considered saturated if no additional conversion is
required between the target and return values. If the target is
saturated when attempting to truncate from a vector, there is an
opportunity to optimize it.

Previously, each architecture had its own attempt at optimization,
leading to redundant code. This patch implements common logic by
introducing three new ISDs:

ISD::TRUNCATE_SSAT_S: When the operand is a signed value and
the range of values matches the range of signed values of the
destination type.

ISD::TRUNCATE_SSAT_U: When the operand is a signed value and
the range of values matches the range of unsigned values of the
destination type.

ISD::TRUNCATE_USAT_U: When the operand is an unsigned value and
the range of values matches the range of unsigned values of the
destination type.

These ISDs indicate a saturated truncate.

Fixes #85903

@llvmbot
Copy link
Member

llvmbot commented Jul 18, 2024

@llvm/pr-subscribers-backend-risc-v

@llvm/pr-subscribers-backend-aarch64

Author: hanbeom (ParkHanbum)

Changes

truncate is saturated if no additional conversion is required
between the target and return values. if the target is saturated
when attempting to crop from a vector, there is an opportunity
to optimize it.

previously, each architecture had an attemping optimization, so there
was redundant code.

this patch implements common logic by adding ISD::TRUNCATE_[US]SAT
to indicate saturated truncate.


Full diff: https://github.com/llvm/llvm-project/pull/99418.diff

9 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/ISDOpcodes.h (+3)
  • (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+131-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp (+2)
  • (modified) llvm/lib/CodeGen/TargetLoweringBase.cpp (+4)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+2)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+16)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+30-10)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+2)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index e6b10209b4767..0b36e5b40da73 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -804,6 +804,9 @@ enum NodeType {
 
   /// TRUNCATE - Completely drop the high bits.
   TRUNCATE,
+  /// TRUNCATE_[SU]SAT - Truncate for saturated operand
+  TRUNCATE_SSAT,
+  TRUNCATE_USAT,
 
   /// [SU]INT_TO_FP - These operators convert integers (whose interpreted sign
   /// depends on the first letter) to floating point.
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 133c9b113e51b..a5242694c9507 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -471,6 +471,8 @@ def sext       : SDNode<"ISD::SIGN_EXTEND", SDTIntExtendOp>;
 def zext       : SDNode<"ISD::ZERO_EXTEND", SDTIntExtendOp>;
 def anyext     : SDNode<"ISD::ANY_EXTEND" , SDTIntExtendOp>;
 def trunc      : SDNode<"ISD::TRUNCATE"   , SDTIntTruncOp>;
+def truncssat  : SDNode<"ISD::TRUNCATE_SSAT", SDTIntTruncOp>;
+def truncusat  : SDNode<"ISD::TRUNCATE_USAT", SDTIntTruncOp>;
 def bitconvert : SDNode<"ISD::BITCAST"    , SDTUnaryOp>;
 def addrspacecast : SDNode<"ISD::ADDRSPACECAST", SDTUnaryOp>;
 def freeze     : SDNode<"ISD::FREEZE"     , SDTFreeze>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 302ad128f4f53..967f313c9885e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -486,6 +486,8 @@ namespace {
     SDValue visitSIGN_EXTEND_INREG(SDNode *N);
     SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
     SDValue visitTRUNCATE(SDNode *N);
+    SDValue visitTRUNCATE_SSAT(SDNode *N);
+    SDValue visitTRUNCATE_USAT(SDNode *N);
     SDValue visitBITCAST(SDNode *N);
     SDValue visitFREEZE(SDNode *N);
     SDValue visitBUILD_PAIR(SDNode *N);
@@ -1907,6 +1909,8 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::ZERO_EXTEND_VECTOR_INREG:
   case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
   case ISD::TRUNCATE:           return visitTRUNCATE(N);
+  case ISD::TRUNCATE_SSAT:      return visitTRUNCATE_SSAT(N);
+  case ISD::TRUNCATE_USAT:      return visitTRUNCATE_USAT(N);
   case ISD::BITCAST:            return visitBITCAST(N);
   case ISD::BUILD_PAIR:         return visitBUILD_PAIR(N);
   case ISD::FADD:               return visitFADD(N);
@@ -13154,7 +13158,8 @@ SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
   unsigned CastOpcode = Cast->getOpcode();
   assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
           CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
-          CastOpcode == ISD::FP_ROUND) &&
+          CastOpcode == ISD::TRUNCATE_SSAT ||
+          CastOpcode == ISD::TRUNCATE_USAT || CastOpcode == ISD::FP_ROUND) &&
          "Unexpected opcode for vector select narrowing/widening");
 
   // We only do this transform before legal ops because the pattern may be
@@ -14867,6 +14872,119 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitTRUNCATE_USAT(SDNode *N) {
+  EVT VT = N->getValueType(0);
+  SDValue N0 = N->getOperand(0);
+  SDValue FPInstr = N0.getOpcode() == ISD::SMAX ? N0.getOperand(0) : N0;
+  if (FPInstr.getOpcode() == ISD::FP_TO_SINT ||
+      FPInstr.getOpcode() == ISD::FP_TO_UINT) {
+    EVT FPVT = FPInstr.getOperand(0).getValueType();
+    if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT,
+                                                          FPVT, VT))
+      return SDValue();
+    SDValue Sat = DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(FPInstr), VT,
+                              FPInstr.getOperand(0),
+                              DAG.getValueType(VT.getScalarType()));
+    return Sat;
+  }
+
+  return SDValue();
+}
+
+SDValue DAGCombiner::visitTRUNCATE_SSAT(SDNode *N) { return SDValue(); }
+
+/// Detect patterns of truncation with unsigned saturation:
+///
+/// 1. (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
+///   Return the source value x to be truncated or SDValue() if the pattern was
+///   not matched.
+///
+/// 2. (truncate (smin (smax (x, C1), C2)) to dest_type),
+///   where C1 >= 0 and C2 is unsigned max of destination type.
+///
+///    (truncate (smax (smin (x, C2), C1)) to dest_type)
+///   where C1 >= 0, C2 is unsigned max of destination type and C1 <= C2.
+///
+///   These two patterns are equivalent to:
+///   (truncate (umin (smax(x, C1), unsigned_max_of_dest_type)) to dest_type)
+///   So return the smax(x, C1) value to be truncated or SDValue() if the
+///   pattern was not matched.
+static SDValue detectUSatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
+                                 const SDLoc &DL) {
+  EVT InVT = In.getValueType();
+
+  // Saturation with truncation. We truncate from InVT to VT.
+  assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() &&
+         "Unexpected types for truncate operation");
+
+  // Match min/max and return limit value as a parameter.
+  auto MatchMinMax = [](SDValue V, unsigned Opcode, APInt &Limit) -> SDValue {
+    if (V.getOpcode() == Opcode &&
+        ISD::isConstantSplatVector(V.getOperand(1).getNode(), Limit))
+      return V.getOperand(0);
+    return SDValue();
+  };
+
+  APInt C1, C2;
+  if (SDValue UMin = MatchMinMax(In, ISD::UMIN, C2))
+    // C2 should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according
+    // the element size of the destination type.
+    if (C2.isMask(VT.getScalarSizeInBits()))
+      return UMin;
+
+  if (SDValue SMin = MatchMinMax(In, ISD::SMIN, C2))
+    if (MatchMinMax(SMin, ISD::SMAX, C1))
+      if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()))
+        return SMin;
+
+  if (SDValue SMax = MatchMinMax(In, ISD::SMAX, C1))
+    if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, C2))
+      if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()) &&
+          C2.uge(C1))
+        return DAG.getNode(ISD::SMAX, DL, InVT, SMin, In.getOperand(1));
+
+  return SDValue();
+}
+
+/// Detect patterns of truncation with signed saturation:
+/// (truncate (smin ((smax (x, signed_min_of_dest_type)),
+///                  signed_max_of_dest_type)) to dest_type)
+/// or:
+/// (truncate (smax ((smin (x, signed_max_of_dest_type)),
+///                  signed_min_of_dest_type)) to dest_type).
+/// With MatchPackUS, the smax/smin range is [0, unsigned_max_of_dest_type].
+/// Return the source value to be truncated or SDValue() if the pattern was not
+/// matched.
+static SDValue detectSSatPattern(SDValue In, EVT VT) {
+  unsigned NumDstBits = VT.getScalarSizeInBits();
+  unsigned NumSrcBits = In.getScalarValueSizeInBits();
+  assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
+
+  auto MatchMinMax = [](SDValue V, unsigned Opcode,
+                        const APInt &Limit) -> SDValue {
+    APInt C;
+    if (V.getOpcode() == Opcode &&
+        ISD::isConstantSplatVector(V.getOperand(1).getNode(), C) && C == Limit)
+      return V.getOperand(0);
+    return SDValue();
+  };
+
+  APInt SignedMax, SignedMin;
+  SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
+  SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
+  if (SDValue SMin = MatchMinMax(In, ISD::SMIN, SignedMax)) {
+    if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, SignedMin)) {
+      return SMax;
+    }
+  }
+  if (SDValue SMax = MatchMinMax(In, ISD::SMAX, SignedMin)) {
+    if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, SignedMax)) {
+      return SMin;
+    }
+  }
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   EVT VT = N->getValueType(0);
@@ -14874,6 +14992,18 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
   bool isLE = DAG.getDataLayout().isLittleEndian();
   SDLoc DL(N);
 
+  if (!LegalOperations && N->getOpcode() == ISD::TRUNCATE) {
+    if (TLI.isOperationLegalOrCustom(ISD::TRUNCATE_SSAT, SrcVT)) {
+      if (SDValue SSatVal = detectSSatPattern(N0, VT))
+        return DAG.getNode(ISD::TRUNCATE_SSAT, DL, VT, SSatVal);
+    }
+
+    if (TLI.isOperationLegalOrCustom(ISD::TRUNCATE_USAT, SrcVT)) {
+      if (SDValue USatVal = detectUSatPattern(N0, VT, DAG, DL))
+        return DAG.getNode(ISD::TRUNCATE_USAT, DL, VT, USatVal);
+    }
+  }
+
   // trunc(undef) = undef
   if (N0.isUndef())
     return DAG.getUNDEF(VT);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index cc8de3a217f82..d3ad6c8acf4f1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -380,6 +380,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::SIGN_EXTEND_VECTOR_INREG:   return "sign_extend_vector_inreg";
   case ISD::ZERO_EXTEND_VECTOR_INREG:   return "zero_extend_vector_inreg";
   case ISD::TRUNCATE:                   return "truncate";
+  case ISD::TRUNCATE_SSAT:              return "truncate_ssat";
+  case ISD::TRUNCATE_USAT:              return "truncate_usat";
   case ISD::FP_ROUND:                   return "fp_round";
   case ISD::STRICT_FP_ROUND:            return "strict_fp_round";
   case ISD::FP_EXTEND:                  return "fp_extend";
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index bf031c00a2449..3e855d5e450df 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -718,6 +718,10 @@ void TargetLoweringBase::initActions() {
     // Absolute difference
     setOperationAction({ISD::ABDS, ISD::ABDU}, VT, Expand);
 
+    // Saturated trunc
+    setOperationAction(ISD::TRUNCATE_SSAT, VT, Expand);
+    setOperationAction(ISD::TRUNCATE_USAT, VT, Expand);
+
     // These default to Expand so they will be expanded to CTLZ/CTTZ by default.
     setOperationAction({ISD::CTLZ_ZERO_UNDEF, ISD::CTTZ_ZERO_UNDEF}, VT,
                        Expand);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index df9b0ae1a632f..504bbaed1c8aa 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1274,6 +1274,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::AVGCEILU, VT, Legal);
       setOperationAction(ISD::ABDS, VT, Legal);
       setOperationAction(ISD::ABDU, VT, Legal);
+      setOperationAction(ISD::TRUNCATE_SSAT, VT, Legal);
+      setOperationAction(ISD::TRUNCATE_USAT, VT, Legal);
     }
 
     // Vector reductions
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index dd11f74882115..322219607407b 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -5343,9 +5343,13 @@ def VImm8000: PatLeaf<(AArch64mvni_msl (i32 127), (i32 264))>;
 // trunc(umin(X, 255)) -> UQXTRN v8i8
 def : Pat<(v8i8 (trunc (umin (v8i16 V128:$Vn), (v8i16 VImmFF)))),
           (UQXTNv8i8 V128:$Vn)>;
+def : Pat<(v8i8 (truncusat (v8i16 V128:$Vn))),
+          (UQXTNv8i8 V128:$Vn)>;
 // trunc(umin(X, 65535)) -> UQXTRN v4i16
 def : Pat<(v4i16 (trunc (umin (v4i32 V128:$Vn), (v4i32 VImmFFFF)))),
           (UQXTNv4i16 V128:$Vn)>;
+def : Pat<(v4i16 (truncusat (v4i32 V128:$Vn))),
+          (UQXTNv4i16 V128:$Vn)>;
 // trunc(smin(smax(X, -128), 128)) -> SQXTRN
 //  with reversed min/max
 def : Pat<(v8i8 (trunc (smin (smax (v8i16 V128:$Vn), (v8i16 VImm80)),
@@ -5354,6 +5358,8 @@ def : Pat<(v8i8 (trunc (smin (smax (v8i16 V128:$Vn), (v8i16 VImm80)),
 def : Pat<(v8i8 (trunc (smax (smin (v8i16 V128:$Vn), (v8i16 VImm7F)),
                              (v8i16 VImm80)))),
           (SQXTNv8i8 V128:$Vn)>;
+def : Pat<(v8i8 (truncssat (v8i16 V128:$Vn))),
+          (SQXTNv8i8 V128:$Vn)>;
 // trunc(smin(smax(X, -32768), 32767)) -> SQXTRN
 //  with reversed min/max
 def : Pat<(v4i16 (trunc (smin (smax (v4i32 V128:$Vn), (v4i32 VImm8000)),
@@ -5362,6 +5368,8 @@ def : Pat<(v4i16 (trunc (smin (smax (v4i32 V128:$Vn), (v4i32 VImm8000)),
 def : Pat<(v4i16 (trunc (smax (smin (v4i32 V128:$Vn), (v4i32 VImm7FFF)),
                               (v4i32 VImm8000)))),
           (SQXTNv4i16 V128:$Vn)>;
+def : Pat<(v4i16 (truncssat (v4i32 V128:$Vn))),
+          (SQXTNv4i16 V128:$Vn)>;
 
 // concat_vectors(Vd, trunc(smin(smax Vm, -128), 127) ~> SQXTN2(Vd, Vn)
 // with reversed min/max
@@ -5375,6 +5383,10 @@ def : Pat<(v16i8 (concat_vectors
                  (v8i8 (trunc (smax (smin (v8i16 V128:$Vn), (v8i16 VImm7F)),
                                           (v8i16 VImm80)))))),
           (SQXTNv16i8 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
+def : Pat<(v16i8 (concat_vectors
+                 (v8i8 V64:$Vd),
+                 (v8i8 (truncssat (v8i16 V128:$Vn))))),
+          (SQXTNv16i8 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
 
 // concat_vectors(Vd, trunc(smin(smax Vm, -32768), 32767) ~> SQXTN2(Vd, Vn)
 // with reversed min/max
@@ -5388,6 +5400,10 @@ def : Pat<(v8i16 (concat_vectors
                  (v4i16 (trunc (smax (smin (v4i32 V128:$Vn), (v4i32 VImm7FFF)),
                                            (v4i32 VImm8000)))))),
           (SQXTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
+def : Pat<(v8i16 (concat_vectors
+                 (v4i16 V64:$Vd),
+                 (v4i16 (truncssat (v4i32 V128:$Vn))))),
+          (SQXTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
 
 // Select BSWAP vector instructions into REV instructions
 def : Pat<(v4i16 (bswap (v4i16 V64:$Rn))), 
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 953196a586b6e..3b54416d1b5b2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -853,7 +853,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
       // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL"
       // nodes which truncate by one power of two at a time.
-      setOperationAction(ISD::TRUNCATE, VT, Custom);
+      setOperationAction(
+          {ISD::TRUNCATE, ISD::TRUNCATE_SSAT, ISD::TRUNCATE_USAT}, VT, Custom);
 
       // Custom-lower insert/extract operations to simplify patterns.
       setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT,
@@ -1168,7 +1169,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
         setOperationAction(ISD::SELECT, VT, Custom);
 
-        setOperationAction(ISD::TRUNCATE, VT, Custom);
+        setOperationAction(
+            {ISD::TRUNCATE, ISD::TRUNCATE_SSAT, ISD::TRUNCATE_USAT}, VT,
+            Custom);
 
         setOperationAction(ISD::BITCAST, VT, Custom);
 
@@ -1479,8 +1482,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN});
 
   if ((Subtarget.hasStdExtZbs() && Subtarget.is64Bit()) ||
-      Subtarget.hasStdExtV())
+      Subtarget.hasStdExtV()) {
     setTargetDAGCombine(ISD::TRUNCATE);
+    setTargetDAGCombine(ISD::TRUNCATE_SSAT);
+    setTargetDAGCombine(ISD::TRUNCATE_USAT);
+  }
 
   if (Subtarget.hasStdExtZbkb())
     setTargetDAGCombine(ISD::BITREVERSE);
@@ -6092,7 +6098,7 @@ static bool hasMergeOp(unsigned Opcode) {
          Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
-                    130 &&
+                    132 &&
                 RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
                         ISD::FIRST_TARGET_STRICTFP_OPCODE ==
                     21 &&
@@ -6118,7 +6124,7 @@ static bool hasMaskOp(unsigned Opcode) {
          Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
-                    130 &&
+                    132 &&
                 RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
                         ISD::FIRST_TARGET_STRICTFP_OPCODE ==
                     21 &&
@@ -6389,6 +6395,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     return DAG.getNode(RISCVISD::BREV8, DL, VT, BSwap);
   }
   case ISD::TRUNCATE:
+  case ISD::TRUNCATE_SSAT:
+  case ISD::TRUNCATE_USAT:
     // Only custom-lower vector truncates
     if (!Op.getSimpleValueType().isVector())
       return Op;
@@ -8275,11 +8283,15 @@ SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op,
 
   LLVMContext &Context = *DAG.getContext();
   const ElementCount Count = ContainerVT.getVectorElementCount();
+  unsigned NewOpc = RISCVISD::TRUNCATE_VECTOR_VL;
+  if (Op.getOpcode() == ISD::TRUNCATE_SSAT)
+    NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_SSAT;
+  else if (Op.getOpcode() == ISD::TRUNCATE_USAT)
+    NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_USAT;
   do {
     SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2);
     EVT ResultVT = EVT::getVectorVT(Context, SrcEltVT, Count);
-    Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result,
-                         Mask, VL);
+    Result = DAG.getNode(NewOpc, DL, ResultVT, Result, Mask, VL);
   } while (SrcEltVT != DstEltVT);
 
   if (SrcVT.isFixedLengthVector())
@@ -16512,7 +16524,9 @@ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
 // minimum value.
 static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
                                     const RISCVSubtarget &Subtarget) {
-  assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL);
+  assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL ||
+         N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_SSAT ||
+         N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_USAT);
 
   MVT VT = N->getSimpleValueType(0);
 
@@ -16617,9 +16631,11 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
 
   SDValue Val;
   unsigned ClipOpc;
-  if ((Val = DetectUSatPattern(Src)))
+
+  Val = N->getOperand(0);
+  if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_USAT)
     ClipOpc = RISCVISD::VNCLIPU_VL;
-  else if ((Val = DetectSSatPattern(Src)))
+  else if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_SSAT)
     ClipOpc = RISCVISD::VNCLIP_VL;
   else
     return SDValue();
@@ -16857,6 +16873,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     }
     return SDValue();
   case RISCVISD::TRUNCATE_VECTOR_VL:
+  case RISCVISD::TRUNCATE_VECTOR_VL_SSAT:
+  case RISCVISD::TRUNCATE_VECTOR_VL_USAT:
     if (SDValue V = combineTruncOfSraSext(N, DAG))
       return V;
     return combineTruncToVnclip(N, DAG, Subtarget);
@@ -20433,6 +20451,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(SPLAT_VECTOR_SPLIT_I64_VL)
   NODE_NAME_CASE(READ_VLENB)
   NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
+  NODE_NAME_CASE(TRUNCATE_VECTOR_VL_SSAT)
+  NODE_NAME_CASE(TRUNCATE_VECTOR_VL_USAT)
   NODE_NAME_CASE(VSLIDEUP_VL)
   NODE_NAME_CASE(VSLIDE1UP_VL)
   NODE_NAME_CASE(VSLIDEDOWN_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 0b0ad9229f0b3..3d582fcdaf64b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -181,6 +181,8 @@ enum NodeType : unsigned {
   // Truncates a RVV integer vector by one power-of-two. Carries both an extra
   // mask and VL operand.
   TRUNCATE_VECTOR_VL,
+  TRUNCATE_VECTOR_VL_SSAT,
+  TRUNCATE_VECTOR_VL_USAT,
   // Matches the semantics of vslideup/vslidedown. The first operand is the
   // pass-thru operand, the second is the source vector, the third is the XLenVT
   // index (either constant or non-constant), the fourth is the mask, the fifth

@llvmbot
Copy link
Member

llvmbot commented Jul 18, 2024

@llvm/pr-subscribers-llvm-selectiondag

Author: hanbeom (ParkHanbum)

Changes

truncate is saturated if no additional conversion is required
between the target and return values. if the target is saturated
when attempting to crop from a vector, there is an opportunity
to optimize it.

previously, each architecture had an attemping optimization, so there
was redundant code.

this patch implements common logic by adding ISD::TRUNCATE_[US]SAT
to indicate saturated truncate.


Full diff: https://github.com/llvm/llvm-project/pull/99418.diff

9 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/ISDOpcodes.h (+3)
  • (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+131-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp (+2)
  • (modified) llvm/lib/CodeGen/TargetLoweringBase.cpp (+4)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+2)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+16)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+30-10)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+2)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index e6b10209b4767..0b36e5b40da73 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -804,6 +804,9 @@ enum NodeType {
 
   /// TRUNCATE - Completely drop the high bits.
   TRUNCATE,
+  /// TRUNCATE_[SU]SAT - Truncate for saturated operand
+  TRUNCATE_SSAT,
+  TRUNCATE_USAT,
 
   /// [SU]INT_TO_FP - These operators convert integers (whose interpreted sign
   /// depends on the first letter) to floating point.
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 133c9b113e51b..a5242694c9507 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -471,6 +471,8 @@ def sext       : SDNode<"ISD::SIGN_EXTEND", SDTIntExtendOp>;
 def zext       : SDNode<"ISD::ZERO_EXTEND", SDTIntExtendOp>;
 def anyext     : SDNode<"ISD::ANY_EXTEND" , SDTIntExtendOp>;
 def trunc      : SDNode<"ISD::TRUNCATE"   , SDTIntTruncOp>;
+def truncssat  : SDNode<"ISD::TRUNCATE_SSAT", SDTIntTruncOp>;
+def truncusat  : SDNode<"ISD::TRUNCATE_USAT", SDTIntTruncOp>;
 def bitconvert : SDNode<"ISD::BITCAST"    , SDTUnaryOp>;
 def addrspacecast : SDNode<"ISD::ADDRSPACECAST", SDTUnaryOp>;
 def freeze     : SDNode<"ISD::FREEZE"     , SDTFreeze>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 302ad128f4f53..967f313c9885e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -486,6 +486,8 @@ namespace {
     SDValue visitSIGN_EXTEND_INREG(SDNode *N);
     SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
     SDValue visitTRUNCATE(SDNode *N);
+    SDValue visitTRUNCATE_SSAT(SDNode *N);
+    SDValue visitTRUNCATE_USAT(SDNode *N);
     SDValue visitBITCAST(SDNode *N);
     SDValue visitFREEZE(SDNode *N);
     SDValue visitBUILD_PAIR(SDNode *N);
@@ -1907,6 +1909,8 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::ZERO_EXTEND_VECTOR_INREG:
   case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
   case ISD::TRUNCATE:           return visitTRUNCATE(N);
+  case ISD::TRUNCATE_SSAT:      return visitTRUNCATE_SSAT(N);
+  case ISD::TRUNCATE_USAT:      return visitTRUNCATE_USAT(N);
   case ISD::BITCAST:            return visitBITCAST(N);
   case ISD::BUILD_PAIR:         return visitBUILD_PAIR(N);
   case ISD::FADD:               return visitFADD(N);
@@ -13154,7 +13158,8 @@ SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
   unsigned CastOpcode = Cast->getOpcode();
   assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
           CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
-          CastOpcode == ISD::FP_ROUND) &&
+          CastOpcode == ISD::TRUNCATE_SSAT ||
+          CastOpcode == ISD::TRUNCATE_USAT || CastOpcode == ISD::FP_ROUND) &&
          "Unexpected opcode for vector select narrowing/widening");
 
   // We only do this transform before legal ops because the pattern may be
@@ -14867,6 +14872,119 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitTRUNCATE_USAT(SDNode *N) {
+  EVT VT = N->getValueType(0);
+  SDValue N0 = N->getOperand(0);
+  SDValue FPInstr = N0.getOpcode() == ISD::SMAX ? N0.getOperand(0) : N0;
+  if (FPInstr.getOpcode() == ISD::FP_TO_SINT ||
+      FPInstr.getOpcode() == ISD::FP_TO_UINT) {
+    EVT FPVT = FPInstr.getOperand(0).getValueType();
+    if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT,
+                                                          FPVT, VT))
+      return SDValue();
+    SDValue Sat = DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(FPInstr), VT,
+                              FPInstr.getOperand(0),
+                              DAG.getValueType(VT.getScalarType()));
+    return Sat;
+  }
+
+  return SDValue();
+}
+
+SDValue DAGCombiner::visitTRUNCATE_SSAT(SDNode *N) { return SDValue(); }
+
+/// Detect patterns of truncation with unsigned saturation:
+///
+/// 1. (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
+///   Return the source value x to be truncated or SDValue() if the pattern was
+///   not matched.
+///
+/// 2. (truncate (smin (smax (x, C1), C2)) to dest_type),
+///   where C1 >= 0 and C2 is unsigned max of destination type.
+///
+///    (truncate (smax (smin (x, C2), C1)) to dest_type)
+///   where C1 >= 0, C2 is unsigned max of destination type and C1 <= C2.
+///
+///   These two patterns are equivalent to:
+///   (truncate (umin (smax(x, C1), unsigned_max_of_dest_type)) to dest_type)
+///   So return the smax(x, C1) value to be truncated or SDValue() if the
+///   pattern was not matched.
+static SDValue detectUSatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
+                                 const SDLoc &DL) {
+  EVT InVT = In.getValueType();
+
+  // Saturation with truncation. We truncate from InVT to VT.
+  assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() &&
+         "Unexpected types for truncate operation");
+
+  // Match min/max and return limit value as a parameter.
+  auto MatchMinMax = [](SDValue V, unsigned Opcode, APInt &Limit) -> SDValue {
+    if (V.getOpcode() == Opcode &&
+        ISD::isConstantSplatVector(V.getOperand(1).getNode(), Limit))
+      return V.getOperand(0);
+    return SDValue();
+  };
+
+  APInt C1, C2;
+  if (SDValue UMin = MatchMinMax(In, ISD::UMIN, C2))
+    // C2 should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according
+    // the element size of the destination type.
+    if (C2.isMask(VT.getScalarSizeInBits()))
+      return UMin;
+
+  if (SDValue SMin = MatchMinMax(In, ISD::SMIN, C2))
+    if (MatchMinMax(SMin, ISD::SMAX, C1))
+      if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()))
+        return SMin;
+
+  if (SDValue SMax = MatchMinMax(In, ISD::SMAX, C1))
+    if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, C2))
+      if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()) &&
+          C2.uge(C1))
+        return DAG.getNode(ISD::SMAX, DL, InVT, SMin, In.getOperand(1));
+
+  return SDValue();
+}
+
+/// Detect patterns of truncation with signed saturation:
+/// (truncate (smin ((smax (x, signed_min_of_dest_type)),
+///                  signed_max_of_dest_type)) to dest_type)
+/// or:
+/// (truncate (smax ((smin (x, signed_max_of_dest_type)),
+///                  signed_min_of_dest_type)) to dest_type).
+/// With MatchPackUS, the smax/smin range is [0, unsigned_max_of_dest_type].
+/// Return the source value to be truncated or SDValue() if the pattern was not
+/// matched.
+static SDValue detectSSatPattern(SDValue In, EVT VT) {
+  unsigned NumDstBits = VT.getScalarSizeInBits();
+  unsigned NumSrcBits = In.getScalarValueSizeInBits();
+  assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
+
+  auto MatchMinMax = [](SDValue V, unsigned Opcode,
+                        const APInt &Limit) -> SDValue {
+    APInt C;
+    if (V.getOpcode() == Opcode &&
+        ISD::isConstantSplatVector(V.getOperand(1).getNode(), C) && C == Limit)
+      return V.getOperand(0);
+    return SDValue();
+  };
+
+  APInt SignedMax, SignedMin;
+  SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
+  SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
+  if (SDValue SMin = MatchMinMax(In, ISD::SMIN, SignedMax)) {
+    if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, SignedMin)) {
+      return SMax;
+    }
+  }
+  if (SDValue SMax = MatchMinMax(In, ISD::SMAX, SignedMin)) {
+    if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, SignedMax)) {
+      return SMin;
+    }
+  }
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   EVT VT = N->getValueType(0);
@@ -14874,6 +14992,18 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
   bool isLE = DAG.getDataLayout().isLittleEndian();
   SDLoc DL(N);
 
+  if (!LegalOperations && N->getOpcode() == ISD::TRUNCATE) {
+    if (TLI.isOperationLegalOrCustom(ISD::TRUNCATE_SSAT, SrcVT)) {
+      if (SDValue SSatVal = detectSSatPattern(N0, VT))
+        return DAG.getNode(ISD::TRUNCATE_SSAT, DL, VT, SSatVal);
+    }
+
+    if (TLI.isOperationLegalOrCustom(ISD::TRUNCATE_USAT, SrcVT)) {
+      if (SDValue USatVal = detectUSatPattern(N0, VT, DAG, DL))
+        return DAG.getNode(ISD::TRUNCATE_USAT, DL, VT, USatVal);
+    }
+  }
+
   // trunc(undef) = undef
   if (N0.isUndef())
     return DAG.getUNDEF(VT);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index cc8de3a217f82..d3ad6c8acf4f1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -380,6 +380,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::SIGN_EXTEND_VECTOR_INREG:   return "sign_extend_vector_inreg";
   case ISD::ZERO_EXTEND_VECTOR_INREG:   return "zero_extend_vector_inreg";
   case ISD::TRUNCATE:                   return "truncate";
+  case ISD::TRUNCATE_SSAT:              return "truncate_ssat";
+  case ISD::TRUNCATE_USAT:              return "truncate_usat";
   case ISD::FP_ROUND:                   return "fp_round";
   case ISD::STRICT_FP_ROUND:            return "strict_fp_round";
   case ISD::FP_EXTEND:                  return "fp_extend";
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index bf031c00a2449..3e855d5e450df 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -718,6 +718,10 @@ void TargetLoweringBase::initActions() {
     // Absolute difference
     setOperationAction({ISD::ABDS, ISD::ABDU}, VT, Expand);
 
+    // Saturated trunc
+    setOperationAction(ISD::TRUNCATE_SSAT, VT, Expand);
+    setOperationAction(ISD::TRUNCATE_USAT, VT, Expand);
+
     // These default to Expand so they will be expanded to CTLZ/CTTZ by default.
     setOperationAction({ISD::CTLZ_ZERO_UNDEF, ISD::CTTZ_ZERO_UNDEF}, VT,
                        Expand);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index df9b0ae1a632f..504bbaed1c8aa 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1274,6 +1274,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::AVGCEILU, VT, Legal);
       setOperationAction(ISD::ABDS, VT, Legal);
       setOperationAction(ISD::ABDU, VT, Legal);
+      setOperationAction(ISD::TRUNCATE_SSAT, VT, Legal);
+      setOperationAction(ISD::TRUNCATE_USAT, VT, Legal);
     }
 
     // Vector reductions
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index dd11f74882115..322219607407b 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -5343,9 +5343,13 @@ def VImm8000: PatLeaf<(AArch64mvni_msl (i32 127), (i32 264))>;
 // trunc(umin(X, 255)) -> UQXTRN v8i8
 def : Pat<(v8i8 (trunc (umin (v8i16 V128:$Vn), (v8i16 VImmFF)))),
           (UQXTNv8i8 V128:$Vn)>;
+def : Pat<(v8i8 (truncusat (v8i16 V128:$Vn))),
+          (UQXTNv8i8 V128:$Vn)>;
 // trunc(umin(X, 65535)) -> UQXTRN v4i16
 def : Pat<(v4i16 (trunc (umin (v4i32 V128:$Vn), (v4i32 VImmFFFF)))),
           (UQXTNv4i16 V128:$Vn)>;
+def : Pat<(v4i16 (truncusat (v4i32 V128:$Vn))),
+          (UQXTNv4i16 V128:$Vn)>;
 // trunc(smin(smax(X, -128), 128)) -> SQXTRN
 //  with reversed min/max
 def : Pat<(v8i8 (trunc (smin (smax (v8i16 V128:$Vn), (v8i16 VImm80)),
@@ -5354,6 +5358,8 @@ def : Pat<(v8i8 (trunc (smin (smax (v8i16 V128:$Vn), (v8i16 VImm80)),
 def : Pat<(v8i8 (trunc (smax (smin (v8i16 V128:$Vn), (v8i16 VImm7F)),
                              (v8i16 VImm80)))),
           (SQXTNv8i8 V128:$Vn)>;
+def : Pat<(v8i8 (truncssat (v8i16 V128:$Vn))),
+          (SQXTNv8i8 V128:$Vn)>;
 // trunc(smin(smax(X, -32768), 32767)) -> SQXTRN
 //  with reversed min/max
 def : Pat<(v4i16 (trunc (smin (smax (v4i32 V128:$Vn), (v4i32 VImm8000)),
@@ -5362,6 +5368,8 @@ def : Pat<(v4i16 (trunc (smin (smax (v4i32 V128:$Vn), (v4i32 VImm8000)),
 def : Pat<(v4i16 (trunc (smax (smin (v4i32 V128:$Vn), (v4i32 VImm7FFF)),
                               (v4i32 VImm8000)))),
           (SQXTNv4i16 V128:$Vn)>;
+def : Pat<(v4i16 (truncssat (v4i32 V128:$Vn))),
+          (SQXTNv4i16 V128:$Vn)>;
 
 // concat_vectors(Vd, trunc(smin(smax Vm, -128), 127) ~> SQXTN2(Vd, Vn)
 // with reversed min/max
@@ -5375,6 +5383,10 @@ def : Pat<(v16i8 (concat_vectors
                  (v8i8 (trunc (smax (smin (v8i16 V128:$Vn), (v8i16 VImm7F)),
                                           (v8i16 VImm80)))))),
           (SQXTNv16i8 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
+def : Pat<(v16i8 (concat_vectors
+                 (v8i8 V64:$Vd),
+                 (v8i8 (truncssat (v8i16 V128:$Vn))))),
+          (SQXTNv16i8 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
 
 // concat_vectors(Vd, trunc(smin(smax Vm, -32768), 32767) ~> SQXTN2(Vd, Vn)
 // with reversed min/max
@@ -5388,6 +5400,10 @@ def : Pat<(v8i16 (concat_vectors
                  (v4i16 (trunc (smax (smin (v4i32 V128:$Vn), (v4i32 VImm7FFF)),
                                            (v4i32 VImm8000)))))),
           (SQXTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
+def : Pat<(v8i16 (concat_vectors
+                 (v4i16 V64:$Vd),
+                 (v4i16 (truncssat (v4i32 V128:$Vn))))),
+          (SQXTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
 
 // Select BSWAP vector instructions into REV instructions
 def : Pat<(v4i16 (bswap (v4i16 V64:$Rn))), 
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 953196a586b6e..3b54416d1b5b2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -853,7 +853,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
       // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL"
       // nodes which truncate by one power of two at a time.
-      setOperationAction(ISD::TRUNCATE, VT, Custom);
+      setOperationAction(
+          {ISD::TRUNCATE, ISD::TRUNCATE_SSAT, ISD::TRUNCATE_USAT}, VT, Custom);
 
       // Custom-lower insert/extract operations to simplify patterns.
       setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT,
@@ -1168,7 +1169,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
         setOperationAction(ISD::SELECT, VT, Custom);
 
-        setOperationAction(ISD::TRUNCATE, VT, Custom);
+        setOperationAction(
+            {ISD::TRUNCATE, ISD::TRUNCATE_SSAT, ISD::TRUNCATE_USAT}, VT,
+            Custom);
 
         setOperationAction(ISD::BITCAST, VT, Custom);
 
@@ -1479,8 +1482,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN});
 
   if ((Subtarget.hasStdExtZbs() && Subtarget.is64Bit()) ||
-      Subtarget.hasStdExtV())
+      Subtarget.hasStdExtV()) {
     setTargetDAGCombine(ISD::TRUNCATE);
+    setTargetDAGCombine(ISD::TRUNCATE_SSAT);
+    setTargetDAGCombine(ISD::TRUNCATE_USAT);
+  }
 
   if (Subtarget.hasStdExtZbkb())
     setTargetDAGCombine(ISD::BITREVERSE);
@@ -6092,7 +6098,7 @@ static bool hasMergeOp(unsigned Opcode) {
          Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
-                    130 &&
+                    132 &&
                 RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
                         ISD::FIRST_TARGET_STRICTFP_OPCODE ==
                     21 &&
@@ -6118,7 +6124,7 @@ static bool hasMaskOp(unsigned Opcode) {
          Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
-                    130 &&
+                    132 &&
                 RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
                         ISD::FIRST_TARGET_STRICTFP_OPCODE ==
                     21 &&
@@ -6389,6 +6395,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     return DAG.getNode(RISCVISD::BREV8, DL, VT, BSwap);
   }
   case ISD::TRUNCATE:
+  case ISD::TRUNCATE_SSAT:
+  case ISD::TRUNCATE_USAT:
     // Only custom-lower vector truncates
     if (!Op.getSimpleValueType().isVector())
       return Op;
@@ -8275,11 +8283,15 @@ SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op,
 
   LLVMContext &Context = *DAG.getContext();
   const ElementCount Count = ContainerVT.getVectorElementCount();
+  unsigned NewOpc = RISCVISD::TRUNCATE_VECTOR_VL;
+  if (Op.getOpcode() == ISD::TRUNCATE_SSAT)
+    NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_SSAT;
+  else if (Op.getOpcode() == ISD::TRUNCATE_USAT)
+    NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_USAT;
   do {
     SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2);
     EVT ResultVT = EVT::getVectorVT(Context, SrcEltVT, Count);
-    Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result,
-                         Mask, VL);
+    Result = DAG.getNode(NewOpc, DL, ResultVT, Result, Mask, VL);
   } while (SrcEltVT != DstEltVT);
 
   if (SrcVT.isFixedLengthVector())
@@ -16512,7 +16524,9 @@ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
 // minimum value.
 static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
                                     const RISCVSubtarget &Subtarget) {
-  assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL);
+  assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL ||
+         N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_SSAT ||
+         N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_USAT);
 
   MVT VT = N->getSimpleValueType(0);
 
@@ -16617,9 +16631,11 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
 
   SDValue Val;
   unsigned ClipOpc;
-  if ((Val = DetectUSatPattern(Src)))
+
+  Val = N->getOperand(0);
+  if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_USAT)
     ClipOpc = RISCVISD::VNCLIPU_VL;
-  else if ((Val = DetectSSatPattern(Src)))
+  else if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_SSAT)
     ClipOpc = RISCVISD::VNCLIP_VL;
   else
     return SDValue();
@@ -16857,6 +16873,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     }
     return SDValue();
   case RISCVISD::TRUNCATE_VECTOR_VL:
+  case RISCVISD::TRUNCATE_VECTOR_VL_SSAT:
+  case RISCVISD::TRUNCATE_VECTOR_VL_USAT:
     if (SDValue V = combineTruncOfSraSext(N, DAG))
       return V;
     return combineTruncToVnclip(N, DAG, Subtarget);
@@ -20433,6 +20451,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(SPLAT_VECTOR_SPLIT_I64_VL)
   NODE_NAME_CASE(READ_VLENB)
   NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
+  NODE_NAME_CASE(TRUNCATE_VECTOR_VL_SSAT)
+  NODE_NAME_CASE(TRUNCATE_VECTOR_VL_USAT)
   NODE_NAME_CASE(VSLIDEUP_VL)
   NODE_NAME_CASE(VSLIDE1UP_VL)
   NODE_NAME_CASE(VSLIDEDOWN_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 0b0ad9229f0b3..3d582fcdaf64b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -181,6 +181,8 @@ enum NodeType : unsigned {
   // Truncates a RVV integer vector by one power-of-two. Carries both an extra
   // mask and VL operand.
   TRUNCATE_VECTOR_VL,
+  TRUNCATE_VECTOR_VL_SSAT,
+  TRUNCATE_VECTOR_VL_USAT,
   // Matches the semantics of vslideup/vslidedown. The first operand is the
   // pass-thru operand, the second is the source vector, the third is the XLenVT
   // index (either constant or non-constant), the fourth is the mask, the fifth

@@ -804,6 +804,9 @@ enum NodeType {

/// TRUNCATE - Completely drop the high bits.
TRUNCATE,
/// TRUNCATE_[SU]SAT - Truncate for saturated operand
TRUNCATE_SSAT,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is TRUNCATE_SSAT signed input to signed result?
Is TRUNCATE_USAT unsigned input to unsigned result?

I ask because X86's packuswb instructon does signed input to unsigned result so its worth being clearing here.

Copy link
Contributor Author

@ParkHanbum ParkHanbum Jul 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is my opinion and needs confirmation.
It is true that SSAT/USAT is seperated for indicate signed, but 'truncate_[us]sat' means that the range of values for 'truncate' is in the range of values that don't have to care about the sign bit.

We also don't support unsigned type variables, right?
so, I think it is ok that 'truncate_ssat' instruction doesn't care about the type of target value, just return it to result type.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variables don't have signed or unsigned type but operations can interpret their input as signed or unsigned.

I believe as you have defined it TRUNCATE_USAT will interpret the input as an unsigned value and produce an unsigned result in the destination type.

The x86 packuswb interprets the input as signed 16 bits and produces an unsigned 8 bit result. If the input is negative it will return 0. That is a different operation than either operation defined here. It's equivalent to smax with 0 followed by truncate_usat. I could imagine having that operation as a single node. Not suggesting for this patch.

I just want the semantics of the new opcodes documented.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

@RKSimon RKSimon Jul 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So should we have 3 truncsat nodes? I mentioned this on #85903 but hadn't realised other targets had something similar to PACKUS

TRUNCATE_SSAT_S,  // saturate signed input to signed result - truncate(smin(smax(x)))
TRUNCATE_SSAT_U,  // saturate signed input to unsigned result - truncate(smin(smax(x,0)))
TRUNCATE_USAT_U,  // saturate unsigned input to unsigned result - truncate(umin(x))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@RKSimon should add that ISD? is it not necessary for unsigned input to signed input?

Copy link
Collaborator

@RKSimon RKSimon Jul 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we have targets that support TRUNCATE_USAT_S then by all means add it.

return SDValue();
}

SDValue DAGCombiner::visitTRUNCATE_SSAT(SDNode *N) { return SDValue(); }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just not add the visit function instead of adding an empty function?

@@ -5343,9 +5343,13 @@ def VImm8000: PatLeaf<(AArch64mvni_msl (i32 127), (i32 264))>;
// trunc(umin(X, 255)) -> UQXTRN v8i8
def : Pat<(v8i8 (trunc (umin (v8i16 V128:$Vn), (v8i16 VImmFF)))),
(UQXTNv8i8 V128:$Vn)>;
def : Pat<(v8i8 (truncusat (v8i16 V128:$Vn))),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the pattern above be dropped?

Copy link
Contributor Author

@ParkHanbum ParkHanbum Jul 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not possible with the current logic. However, I think it would be possible with additional logic. Would you like to see it included in this patch?

SelectionDAG has 22 nodes:
  t0: ch,glue = EntryToken
              t2: v8f16,ch = CopyFromReg t0, Register:v8f16 %0
            t16: v8i16 = fp_to_sint_sat t2, ValueType:ch:i16
          t19: v8i16 = smin t16, t18
        t22: v8i16 = smax t19, t21
      t23: v8i8 = truncate t22
              t4: v8f16,ch = CopyFromReg t0, Register:v8f16 %1
            t24: v8i16 = fp_to_sint_sat t4, ValueType:ch:i16
          t25: v8i16 = smin t24, t18
        t26: v8i16 = smax t25, t21
      t27: v8i8 = truncate t26
    t13: v16i8 = concat_vectors t23, t27
  t9: ch,glue = CopyToReg t0, Register:v16i8 $q0, t13

I think it need support for fp_to_[su]int_sat.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the t18 and t21 nodes. It looks like you only pasted part of the DAG.

Copy link
Contributor Author

@ParkHanbum ParkHanbum Jul 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@topperc

1173:SelectionDAG has 22 nodes:
1174-  t0: ch,glue = EntryToken
1175-              t2: v8f16,ch = CopyFromReg t0, Register:v8f16 %0
1176-            t16: v8i16 = fp_to_sint_sat t2, ValueType:ch:i16
1177-          t19: v8i16 = smin t16, t18
1178-        t22: v8i16 = smax t19, t21
1179-      t23: v8i8 = truncate t22
1180-              t4: v8f16,ch = CopyFromReg t0, Register:v8f16 %1
1181-            t24: v8i16 = fp_to_sint_sat t4, ValueType:ch:i16
1182-          t25: v8i16 = smin t24, t18
1183-        t26: v8i16 = smax t25, t21
1184-      t27: v8i8 = truncate t26
1185-    t13: v16i8 = concat_vectors t23, t27
1186-  t9: ch,glue = CopyToReg t0, Register:v16i8 $q0, t13
1187-  t18: v8i16 = BUILD_VECTOR Constant:i32<127>, Constant:i32<127>, Constant:i32<127>, Constant:i32<127>, Constant:i32<127>, Constant:i32<127>, Constant:i32<127>, Constant:i32<127>
1188-  t21: v8i16 = BUILD_VECTOR Constant:i32<65408>, Constant:i32<65408>, Constant:i32<65408>, Constant:i32<65408>, Constant:i32<65408>, Constant:i32<65408>, Constant:i32<65408>, Constant:i32<65408>
1189-  t10: ch = AArch64ISD::RET_GLUE t9, Register:v16i8 $q0, t9:1

It was at the bottom.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@topperc I tested this by implementing additional code to support fp_to_[su]int_sat. As a result, if LegalOperations is unchecked, the associated Pattern can be deleted.

Should we continue to check LegalOperations? I can not sure because I'm not a veteran.

if ((Val = DetectUSatPattern(Src)))

Val = N->getOperand(0);
if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_USAT)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is TRUNCATE_VECTOR_VL_USAT the same as RISCVISD::VNCLIPU_VL?

ClipOpc = RISCVISD::VNCLIPU_VL;
else if ((Val = DetectSSatPattern(Src)))
else if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_SSAT)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is TRUNCATE_VECTOR_VL_SSAT the same as VNCLIP_VL?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nevermind. It's not the same because VNCLIP has a shift amount.

@@ -804,6 +804,9 @@ enum NodeType {

/// TRUNCATE - Completely drop the high bits.
TRUNCATE,
/// TRUNCATE_[SU]SAT - Truncate for saturated operand
TRUNCATE_SSAT,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

setTargetDAGCombine(ISD::TRUNCATE);
setTargetDAGCombine(ISD::TRUNCATE_SSAT);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect these 2 lines aren't needed.

@topperc
Copy link
Collaborator

topperc commented Jul 23, 2024

I just submitted #100173 to remove the RISCVISD::VNCLIP* opcodes in favor of the TRUNCATE_VECTOR_VL_*SAT nodes added in this patch. This should allow us to remove the translating from TRUNCATE_VECTOR_VL_SAT to RISCVISD::VNCLIP.

topperc added a commit that referenced this pull request Jul 23, 2024
…USAT opcodes (#100173)

These new opcodes drop the shift amount, rounding mode, and passthru.
Making them exactly like TRUNCATE_VECTOR_VL. The shift amount, rounding
mode, and passthru are added in isel patterns similar to how we
translate TRUNCATE_VECTOR_VL to vnsrl with a shift of 0.

This should simplify #99418 a little.
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
…USAT opcodes (#100173)

Summary:
These new opcodes drop the shift amount, rounding mode, and passthru.
Making them exactly like TRUNCATE_VECTOR_VL. The shift amount, rounding
mode, and passthru are added in isel patterns similar to how we
translate TRUNCATE_VECTOR_VL to vnsrl with a shift of 0.

This should simplify #99418 a little.

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60251265
Copy link

github-actions bot commented Jul 26, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@@ -292,15 +292,15 @@ entry:

; Test the (concat_vectors (X), (trunc(umin(smax(Y, 0), 2^n))))) pattern.

; TODO: %min is a value between 0 and 255 and is within the unsigned range of i8.
; So it is saturated truncate. we have an optimization opportunity.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm pretty sure this should be matching a sqxtun2, that was the intent. I believe because the lower limit is already clamped, that the upper smin is equivalent to umin.
https://godbolt.org/z/e7ne31TYb
You can see in that example that the midend turned smin into umin.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(right place :))
agree. I also added a comment with that intention and I'm thinking maybe adding code to DAGCombiner can solve it.
do you think I should include this in this patch as well?

SDValue DAGCombiner::visitTRUNCATE_USAT(SDNode *N) {
EVT VT = N->getValueType(0);
SDValue N0 = N->getOperand(0);
SDValue FPInstr = N0.getOpcode() == ISD::SMAX ? N0.getOperand(0) : N0;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to check the other operand of this SMAX?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, missed it. I'll fix it.

SDValue Sat = DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(FPInstr), VT,
FPInstr.getOperand(0),
DAG.getValueType(VT.getScalarType()));
return Sat;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return DAG.getNode... no need for temporary variable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how I can do it. could you give me little advise please?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@topperc how about now? do you think I'm doing properly?

// fold satruated truncate
if (SDValue SaturatedTR = foldToSaturated(N, VT, N0, SrcVT, DL, TLI, DAG)) {
return SaturatedTR;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop curly braces

@@ -814,6 +814,13 @@ enum NodeType {

/// TRUNCATE - Completely drop the high bits.
TRUNCATE,
/// TRUNCATE_[SU]SAT - Truncate for saturated operand
TRUNCATE_SSAT_S, // saturate signed input to signed result -
// truncate(smin(smax(x)))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please spell out the parameters for my eyes? [s|u][min|max] have two parameters.

(UQXTNv8i8 V128:$Vn)>;
// trunc(umin(X, 65535)) -> UQXTRN v4i16
def : Pat<(v4i16 (trunc (umin (v4i32 V128:$Vn), (v4i32 VImmFFFF)))),
def : Pat<(v4i16 (truncusat_u (v4i32 V128:$Vn))),
(UQXTNv4i16 V128:$Vn)>;
// trunc(smin(smax(X, -128), 128)) -> SQXTRN
// with reversed min/max
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "reversed min/max" lines can be removed now.

Copy link
Collaborator

@davemgreen davemgreen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I wasn't around yesterday, thanks for the updates. The code looks good to me now, if others agree.

@ParkHanbum
Copy link
Contributor Author

@RKSimon @davemgreen Is it okay to reduce commits?

@RKSimon
Copy link
Collaborator

RKSimon commented Aug 9, 2024

@RKSimon @davemgreen Is it okay to reduce commits?

Sure - as I mentioned above the commit summary at the top needs editing. Once that's done I'll approve.

A truncate is considered saturated if no additional conversion is
required between the target and return values. If the target is
saturated when attempting to truncate from a vector, there is an
opportunity to optimize it.

Previously, each architecture had its own attempt at optimization,
leading to redundant code. This patch implements common logic by
introducing three new ISDs:

 `ISD::TRUNCATE_SSAT_S`: When the operand is a signed value and
 the range of values matches the range of signed values of the
 destination type.

 `ISD::TRUNCATE_SSAT_U`: When the operand is a signed value and
 the range of values matches the range of unsigned values of the
 destination type.

 `ISD::TRUNCATE_USAT_U`: When the operand is an unsigned value and
 the range of values matches the range of unsigned values of the
 destination type.

These ISDs indicate a saturated truncate.

Fixes llvm#85903
RKSimon added a commit that referenced this pull request Aug 9, 2024
Inspired by #99418 (which hopefully we can replace this code with at some point)
@davemgreen
Copy link
Collaborator

Can you update the description in github with the one in 8d81896 (or the others if you want to combine them)? LLVM uses a squash-and-merge approach to committing PR's, so they will be squashed into a single commit.

@ParkHanbum
Copy link
Contributor Author

do you mean my written message when I request PR? if it is, I done.

Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one final minor

@@ -1908,6 +1909,8 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::ZERO_EXTEND_VECTOR_INREG:
case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
case ISD::TRUNCATE: return visitTRUNCATE(N);
case ISD::TRUNCATE_SSAT_U:
case ISD::TRUNCATE_USAT_U: return visitTRUNCATE_USAT(N);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be called visitTRUNCATE_SAT_U? Or should we just have a visitTRUNCATE_SAT call and handle the unsigned cases inside it?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did we not mean to change this to just case ISD::TRUNCATE_USAT_U: return visitTRUNCATE_USAT(N);, without the TRUNCATE_SSAT_U case? (Sorry if I missed that). Otherwise it will change TRUNCATE_SSAT_U(FP_TO_UINT(x)) to FP_TO_UINT_SAT(x), which will not clamp to the same bounds.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

visitTRUNCATE_USAT has a small task, so I don't see the need to separate SSAT and USAT.
Is it LLVM's way to reduce unnecessary separation and separate them if they are reasonably large?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we can generalize it later if the need arises - but we need to confirm if we should be handling the TRUNCATE_SSAT_U case or not (do we have test coverage?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If that's the case, then we don't have a test for truncate_ssat_u. As @davemgreen commented, it is right to change to call visitTRUNCATE_USAT() in case TRUNCATE_USAT_U.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should I change it so that only case ISD::TRUNCATE_USAT_U: return visitTRUNCATE_USAT(N); remains?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so. Perhaps call it visitTRUNCATE_USAT_U too?

Add support for saturated truncate with the following changes:

- Add action to Legal for types v8i16, v4i32, and v2i64
- Implement `isTypeDesirableForOp` to check for truncate conversions
- Add patterns for saturated truncate of supported types
Add support for saturated truncate by implementing the following changes:

- Add `TRUNCATE_[SU]SAT_[SU]` to the Action target of `TRUNCATE`
- Add `TRUNCATE_[SU]SAT_[SU]` to the TargetLowering target of `TRUNCATE`
- Convert `TRUNCATE_SSAT_S` to `TRUNCATE_VECTOR_VL_SSAT`
- Convert `TRUNCATE_[SU]SAT_U` to `TRUNCATE_VECTOR_VL_USAT`
Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Collaborator

@davemgreen davemgreen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. LGTM

@RKSimon
Copy link
Collaborator

RKSimon commented Aug 13, 2024

@ParkHanbum Are you happy for me to commit this?

@ParkHanbum
Copy link
Contributor Author

ParkHanbum commented Aug 13, 2024

@RKSimon I feel very honored, Sir! please do it!

@ParkHanbum ParkHanbum requested a review from RKSimon August 13, 2024 18:41
Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@RKSimon RKSimon merged commit 0d074ba into llvm:main Aug 14, 2024
6 of 8 checks passed
unsigned NewOpc;
if (Opc == ISD::TRUNCATE_SSAT_S)
NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_SSAT;
else if (Opc == ISD::TRUNCATE_SSAT_U || Opc == ISD::TRUNCATE_USAT_U)
Copy link
Collaborator

@topperc topperc Aug 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How are we lowering 2 different opcodes to the same RISC-V instruction?

@ParkHanbum ParkHanbum deleted the truncate_sat branch October 16, 2024 10:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[SelectionDAG] Add a new ISD Node for vector saturating truncation
6 participants