-
Notifications
You must be signed in to change notification settings - Fork 12.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DAG] Support saturated truncate #99418
Conversation
@llvm/pr-subscribers-backend-risc-v @llvm/pr-subscribers-backend-aarch64 Author: hanbeom (ParkHanbum) Changes
Previously, each architecture attempted this optimization on its own, so this patch implements the common logic by adding the ISD::TRUNCATE_SSAT and ISD::TRUNCATE_USAT nodes. Full diff: https://github.com/llvm/llvm-project/pull/99418.diff 9 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index e6b10209b4767..0b36e5b40da73 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -804,6 +804,9 @@ enum NodeType {
/// TRUNCATE - Completely drop the high bits.
TRUNCATE,
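+ /// TRUNCATE_[SU]SAT - Truncate with saturation. DAGCombiner forms these from
+ /// a truncate of a value clamped to the destination type's signed (SSAT) or
+ /// unsigned (USAT) range.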
+ TRUNCATE_SSAT,
+ TRUNCATE_USAT,
/// [SU]INT_TO_FP - These operators convert integers (whose interpreted sign
/// depends on the first letter) to floating point.
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 133c9b113e51b..a5242694c9507 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -471,6 +471,8 @@ def sext : SDNode<"ISD::SIGN_EXTEND", SDTIntExtendOp>;
def zext : SDNode<"ISD::ZERO_EXTEND", SDTIntExtendOp>;
def anyext : SDNode<"ISD::ANY_EXTEND" , SDTIntExtendOp>;
def trunc : SDNode<"ISD::TRUNCATE" , SDTIntTruncOp>;
+def truncssat : SDNode<"ISD::TRUNCATE_SSAT", SDTIntTruncOp>;
+def truncusat : SDNode<"ISD::TRUNCATE_USAT", SDTIntTruncOp>;
def bitconvert : SDNode<"ISD::BITCAST" , SDTUnaryOp>;
def addrspacecast : SDNode<"ISD::ADDRSPACECAST", SDTUnaryOp>;
def freeze : SDNode<"ISD::FREEZE" , SDTFreeze>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 302ad128f4f53..967f313c9885e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -486,6 +486,8 @@ namespace {
SDValue visitSIGN_EXTEND_INREG(SDNode *N);
SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
SDValue visitTRUNCATE(SDNode *N);
+ SDValue visitTRUNCATE_SSAT(SDNode *N);
+ SDValue visitTRUNCATE_USAT(SDNode *N);
SDValue visitBITCAST(SDNode *N);
SDValue visitFREEZE(SDNode *N);
SDValue visitBUILD_PAIR(SDNode *N);
@@ -1907,6 +1909,8 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::ZERO_EXTEND_VECTOR_INREG:
case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
case ISD::TRUNCATE: return visitTRUNCATE(N);
+ case ISD::TRUNCATE_SSAT: return visitTRUNCATE_SSAT(N);
+ case ISD::TRUNCATE_USAT: return visitTRUNCATE_USAT(N);
case ISD::BITCAST: return visitBITCAST(N);
case ISD::BUILD_PAIR: return visitBUILD_PAIR(N);
case ISD::FADD: return visitFADD(N);
@@ -13154,7 +13158,8 @@ SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
unsigned CastOpcode = Cast->getOpcode();
assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
- CastOpcode == ISD::FP_ROUND) &&
+ CastOpcode == ISD::TRUNCATE_SSAT ||
+ CastOpcode == ISD::TRUNCATE_USAT || CastOpcode == ISD::FP_ROUND) &&
"Unexpected opcode for vector select narrowing/widening");
// We only do this transform before legal ops because the pattern may be
@@ -14867,6 +14872,119 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitTRUNCATE_USAT(SDNode *N) { // Combine hook for ISD::TRUNCATE_USAT: tries to form FP_TO_UINT_SAT.
+ EVT VT = N->getValueType(0);
+ SDValue N0 = N->getOperand(0);
+ SDValue FPInstr = N0.getOpcode() == ISD::SMAX ? N0.getOperand(0) : N0; // Look through one SMAX. NOTE(review): its bound (operand 1) is never checked.
+ if (FPInstr.getOpcode() == ISD::FP_TO_SINT || // NOTE(review): a signed conversion is also folded to an unsigned saturate; confirm for negative inputs.
+ FPInstr.getOpcode() == ISD::FP_TO_UINT) {
+ EVT FPVT = FPInstr.getOperand(0).getValueType(); // Type of the floating-point source.
+ if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT,
+ FPVT, VT))
+ return SDValue(); // Target hook rejected FP_TO_UINT_SAT for (FPVT, VT).
+ SDValue Sat = DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(FPInstr), VT,
+ FPInstr.getOperand(0),
+ DAG.getValueType(VT.getScalarType())); // Saturation width operand is the result's scalar type.
+ return Sat;
+ }
+
+ return SDValue(); // Operand is not an FP-to-integer conversion: no combine.
+}
+
+SDValue DAGCombiner::visitTRUNCATE_SSAT(SDNode *N) { return SDValue(); } // Stub: performs no combine and always returns an empty SDValue.
+
+/// Detect patterns of truncation with unsigned saturation:
+///
+/// 1. (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
+/// Return the source value x to be truncated or SDValue() if the pattern was
+/// not matched.
+///
+/// 2. (truncate (smin (smax (x, C1), C2)) to dest_type),
+/// where C1 >= 0 and C2 is the unsigned max of the destination type.
+///
+/// (truncate (smax (smin (x, C2), C1)) to dest_type)
+/// where C1 >= 0, C2 is the unsigned max of the destination type and C1 <= C2.
+///
+/// These two patterns are equivalent to:
+/// (truncate (umin (smax(x, C1), unsigned_max_of_dest_type)) to dest_type)
+/// So return the smax(x, C1) value to be truncated or SDValue() if the
+/// pattern was not matched. All limits must be constant splats.
+static SDValue detectUSatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
+ const SDLoc &DL) {
+ EVT InVT = In.getValueType();
+
+ // Saturation with truncation. We truncate from InVT to VT.
+ assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() &&
+ "Unexpected types for truncate operation");
+
+ // Match min/max and return limit value as a parameter.
+ auto MatchMinMax = [](SDValue V, unsigned Opcode, APInt &Limit) -> SDValue {
+ if (V.getOpcode() == Opcode &&
+ ISD::isConstantSplatVector(V.getOperand(1).getNode(), Limit))
+ return V.getOperand(0); // The non-constant side of the min/max.
+ return SDValue();
+ };
+
+ APInt C1, C2; // C1: smax limit, C2: umin/smin limit.
+ if (SDValue UMin = MatchMinMax(In, ISD::UMIN, C2))
+ // C2 should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according
+ // to the element size of the destination type.
+ if (C2.isMask(VT.getScalarSizeInBits()))
+ return UMin; // Despite the name, this is umin's operand x.
+
+ if (SDValue SMin = MatchMinMax(In, ISD::SMIN, C2))
+ if (MatchMinMax(SMin, ISD::SMAX, C1)) // Only C1 is needed from this match.
+ if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()))
+ return SMin; // SMin holds smin's operand, i.e. the smax(x, C1) node.
+
+ if (SDValue SMax = MatchMinMax(In, ISD::SMAX, C1))
+ if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, C2))
+ if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()) &&
+ C2.uge(C1))
+ return DAG.getNode(ISD::SMAX, DL, InVT, SMin, In.getOperand(1)); // Build smax(x, C1) with the smin stripped.
+
+ return SDValue(); // No unsigned-saturation pattern found.
+}
+
+/// Detect patterns of truncation with signed saturation:
+/// (truncate (smin ((smax (x, signed_min_of_dest_type)),
+/// signed_max_of_dest_type)) to dest_type)
+/// or:
+/// (truncate (smax ((smin (x, signed_max_of_dest_type)),
+/// signed_min_of_dest_type)) to dest_type).
+/// Both limits are sign-extended to the source width and must be constant splats.
+/// Return the source value to be truncated or SDValue() if the pattern was not
+/// matched.
+static SDValue detectSSatPattern(SDValue In, EVT VT) {
+ unsigned NumDstBits = VT.getScalarSizeInBits();
+ unsigned NumSrcBits = In.getScalarValueSizeInBits();
+ assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
+
+ auto MatchMinMax = [](SDValue V, unsigned Opcode,
+ const APInt &Limit) -> SDValue { // Matches Opcode(x, splat Limit) and yields x.
+ APInt C;
+ if (V.getOpcode() == Opcode &&
+ ISD::isConstantSplatVector(V.getOperand(1).getNode(), C) && C == Limit)
+ return V.getOperand(0);
+ return SDValue();
+ };
+
+ APInt SignedMax, SignedMin;
+ SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits); // Destination max at source width.
+ SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits); // Destination min at source width.
+ if (SDValue SMin = MatchMinMax(In, ISD::SMIN, SignedMax)) { // smin(smax(x, min), max)
+ if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, SignedMin)) {
+ return SMax; // Holds x, the operand of the inner smax.
+ }
+ }
+ if (SDValue SMax = MatchMinMax(In, ISD::SMAX, SignedMin)) { // smax(smin(x, max), min)
+ if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, SignedMax)) {
+ return SMin; // Holds x, the operand of the inner smin.
+ }
+ }
+ return SDValue(); // Neither clamp order matched.
+}
+
SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
@@ -14874,6 +14992,18 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
bool isLE = DAG.getDataLayout().isLittleEndian();
SDLoc DL(N);
+ if (!LegalOperations && N->getOpcode() == ISD::TRUNCATE) {
+ if (TLI.isOperationLegalOrCustom(ISD::TRUNCATE_SSAT, SrcVT)) {
+ if (SDValue SSatVal = detectSSatPattern(N0, VT))
+ return DAG.getNode(ISD::TRUNCATE_SSAT, DL, VT, SSatVal);
+ }
+
+ if (TLI.isOperationLegalOrCustom(ISD::TRUNCATE_USAT, SrcVT)) {
+ if (SDValue USatVal = detectUSatPattern(N0, VT, DAG, DL))
+ return DAG.getNode(ISD::TRUNCATE_USAT, DL, VT, USatVal);
+ }
+ }
+
// trunc(undef) = undef
if (N0.isUndef())
return DAG.getUNDEF(VT);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index cc8de3a217f82..d3ad6c8acf4f1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -380,6 +380,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::SIGN_EXTEND_VECTOR_INREG: return "sign_extend_vector_inreg";
case ISD::ZERO_EXTEND_VECTOR_INREG: return "zero_extend_vector_inreg";
case ISD::TRUNCATE: return "truncate";
+ case ISD::TRUNCATE_SSAT: return "truncate_ssat";
+ case ISD::TRUNCATE_USAT: return "truncate_usat";
case ISD::FP_ROUND: return "fp_round";
case ISD::STRICT_FP_ROUND: return "strict_fp_round";
case ISD::FP_EXTEND: return "fp_extend";
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index bf031c00a2449..3e855d5e450df 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -718,6 +718,10 @@ void TargetLoweringBase::initActions() {
// Absolute difference
setOperationAction({ISD::ABDS, ISD::ABDU}, VT, Expand);
+ // Saturated trunc
+ setOperationAction(ISD::TRUNCATE_SSAT, VT, Expand);
+ setOperationAction(ISD::TRUNCATE_USAT, VT, Expand);
+
// These default to Expand so they will be expanded to CTLZ/CTTZ by default.
setOperationAction({ISD::CTLZ_ZERO_UNDEF, ISD::CTTZ_ZERO_UNDEF}, VT,
Expand);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index df9b0ae1a632f..504bbaed1c8aa 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1274,6 +1274,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::AVGCEILU, VT, Legal);
setOperationAction(ISD::ABDS, VT, Legal);
setOperationAction(ISD::ABDU, VT, Legal);
+ setOperationAction(ISD::TRUNCATE_SSAT, VT, Legal);
+ setOperationAction(ISD::TRUNCATE_USAT, VT, Legal);
}
// Vector reductions
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index dd11f74882115..322219607407b 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -5343,9 +5343,13 @@ def VImm8000: PatLeaf<(AArch64mvni_msl (i32 127), (i32 264))>;
// trunc(umin(X, 255)) -> UQXTRN v8i8
def : Pat<(v8i8 (trunc (umin (v8i16 V128:$Vn), (v8i16 VImmFF)))),
(UQXTNv8i8 V128:$Vn)>;
+def : Pat<(v8i8 (truncusat (v8i16 V128:$Vn))),
+ (UQXTNv8i8 V128:$Vn)>;
// trunc(umin(X, 65535)) -> UQXTRN v4i16
def : Pat<(v4i16 (trunc (umin (v4i32 V128:$Vn), (v4i32 VImmFFFF)))),
(UQXTNv4i16 V128:$Vn)>;
+def : Pat<(v4i16 (truncusat (v4i32 V128:$Vn))),
+ (UQXTNv4i16 V128:$Vn)>;
// trunc(smin(smax(X, -128), 128)) -> SQXTRN
// with reversed min/max
def : Pat<(v8i8 (trunc (smin (smax (v8i16 V128:$Vn), (v8i16 VImm80)),
@@ -5354,6 +5358,8 @@ def : Pat<(v8i8 (trunc (smin (smax (v8i16 V128:$Vn), (v8i16 VImm80)),
def : Pat<(v8i8 (trunc (smax (smin (v8i16 V128:$Vn), (v8i16 VImm7F)),
(v8i16 VImm80)))),
(SQXTNv8i8 V128:$Vn)>;
+def : Pat<(v8i8 (truncssat (v8i16 V128:$Vn))),
+ (SQXTNv8i8 V128:$Vn)>;
// trunc(smin(smax(X, -32768), 32767)) -> SQXTRN
// with reversed min/max
def : Pat<(v4i16 (trunc (smin (smax (v4i32 V128:$Vn), (v4i32 VImm8000)),
@@ -5362,6 +5368,8 @@ def : Pat<(v4i16 (trunc (smin (smax (v4i32 V128:$Vn), (v4i32 VImm8000)),
def : Pat<(v4i16 (trunc (smax (smin (v4i32 V128:$Vn), (v4i32 VImm7FFF)),
(v4i32 VImm8000)))),
(SQXTNv4i16 V128:$Vn)>;
+def : Pat<(v4i16 (truncssat (v4i32 V128:$Vn))),
+ (SQXTNv4i16 V128:$Vn)>;
// concat_vectors(Vd, trunc(smin(smax Vm, -128), 127) ~> SQXTN2(Vd, Vn)
// with reversed min/max
@@ -5375,6 +5383,10 @@ def : Pat<(v16i8 (concat_vectors
(v8i8 (trunc (smax (smin (v8i16 V128:$Vn), (v8i16 VImm7F)),
(v8i16 VImm80)))))),
(SQXTNv16i8 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
+def : Pat<(v16i8 (concat_vectors
+ (v8i8 V64:$Vd),
+ (v8i8 (truncssat (v8i16 V128:$Vn))))),
+ (SQXTNv16i8 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
// concat_vectors(Vd, trunc(smin(smax Vm, -32768), 32767) ~> SQXTN2(Vd, Vn)
// with reversed min/max
@@ -5388,6 +5400,10 @@ def : Pat<(v8i16 (concat_vectors
(v4i16 (trunc (smax (smin (v4i32 V128:$Vn), (v4i32 VImm7FFF)),
(v4i32 VImm8000)))))),
(SQXTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
+def : Pat<(v8i16 (concat_vectors
+ (v4i16 V64:$Vd),
+ (v4i16 (truncssat (v4i32 V128:$Vn))))),
+ (SQXTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
// Select BSWAP vector instructions into REV instructions
def : Pat<(v4i16 (bswap (v4i16 V64:$Rn))),
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 953196a586b6e..3b54416d1b5b2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -853,7 +853,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
// Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL"
// nodes which truncate by one power of two at a time.
- setOperationAction(ISD::TRUNCATE, VT, Custom);
+ setOperationAction(
+ {ISD::TRUNCATE, ISD::TRUNCATE_SSAT, ISD::TRUNCATE_USAT}, VT, Custom);
// Custom-lower insert/extract operations to simplify patterns.
setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT,
@@ -1168,7 +1169,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SELECT, VT, Custom);
- setOperationAction(ISD::TRUNCATE, VT, Custom);
+ setOperationAction(
+ {ISD::TRUNCATE, ISD::TRUNCATE_SSAT, ISD::TRUNCATE_USAT}, VT,
+ Custom);
setOperationAction(ISD::BITCAST, VT, Custom);
@@ -1479,8 +1482,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN});
if ((Subtarget.hasStdExtZbs() && Subtarget.is64Bit()) ||
- Subtarget.hasStdExtV())
+ Subtarget.hasStdExtV()) {
setTargetDAGCombine(ISD::TRUNCATE);
+ setTargetDAGCombine(ISD::TRUNCATE_SSAT);
+ setTargetDAGCombine(ISD::TRUNCATE_USAT);
+ }
if (Subtarget.hasStdExtZbkb())
setTargetDAGCombine(ISD::BITREVERSE);
@@ -6092,7 +6098,7 @@ static bool hasMergeOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
- 130 &&
+ 132 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
@@ -6118,7 +6124,7 @@ static bool hasMaskOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
- 130 &&
+ 132 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
@@ -6389,6 +6395,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return DAG.getNode(RISCVISD::BREV8, DL, VT, BSwap);
}
case ISD::TRUNCATE:
+ case ISD::TRUNCATE_SSAT:
+ case ISD::TRUNCATE_USAT:
// Only custom-lower vector truncates
if (!Op.getSimpleValueType().isVector())
return Op;
@@ -8275,11 +8283,15 @@ SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op,
LLVMContext &Context = *DAG.getContext();
const ElementCount Count = ContainerVT.getVectorElementCount();
+ unsigned NewOpc = RISCVISD::TRUNCATE_VECTOR_VL;
+ if (Op.getOpcode() == ISD::TRUNCATE_SSAT)
+ NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_SSAT;
+ else if (Op.getOpcode() == ISD::TRUNCATE_USAT)
+ NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_USAT;
do {
SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2);
EVT ResultVT = EVT::getVectorVT(Context, SrcEltVT, Count);
- Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result,
- Mask, VL);
+ Result = DAG.getNode(NewOpc, DL, ResultVT, Result, Mask, VL);
} while (SrcEltVT != DstEltVT);
if (SrcVT.isFixedLengthVector())
@@ -16512,7 +16524,9 @@ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
// minimum value.
static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL);
+ assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL ||
+ N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_SSAT ||
+ N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_USAT);
MVT VT = N->getSimpleValueType(0);
@@ -16617,9 +16631,11 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
SDValue Val;
unsigned ClipOpc;
- if ((Val = DetectUSatPattern(Src)))
+
+ Val = N->getOperand(0);
+ if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_USAT)
ClipOpc = RISCVISD::VNCLIPU_VL;
- else if ((Val = DetectSSatPattern(Src)))
+ else if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_SSAT)
ClipOpc = RISCVISD::VNCLIP_VL;
else
return SDValue();
@@ -16857,6 +16873,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
}
return SDValue();
case RISCVISD::TRUNCATE_VECTOR_VL:
+ case RISCVISD::TRUNCATE_VECTOR_VL_SSAT:
+ case RISCVISD::TRUNCATE_VECTOR_VL_USAT:
if (SDValue V = combineTruncOfSraSext(N, DAG))
return V;
return combineTruncToVnclip(N, DAG, Subtarget);
@@ -20433,6 +20451,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(SPLAT_VECTOR_SPLIT_I64_VL)
NODE_NAME_CASE(READ_VLENB)
NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
+ NODE_NAME_CASE(TRUNCATE_VECTOR_VL_SSAT)
+ NODE_NAME_CASE(TRUNCATE_VECTOR_VL_USAT)
NODE_NAME_CASE(VSLIDEUP_VL)
NODE_NAME_CASE(VSLIDE1UP_VL)
NODE_NAME_CASE(VSLIDEDOWN_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 0b0ad9229f0b3..3d582fcdaf64b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -181,6 +181,8 @@ enum NodeType : unsigned {
// Truncates a RVV integer vector by one power-of-two. Carries both an extra
// mask and VL operand.
TRUNCATE_VECTOR_VL,
+ TRUNCATE_VECTOR_VL_SSAT,
+ TRUNCATE_VECTOR_VL_USAT,
// Matches the semantics of vslideup/vslidedown. The first operand is the
// pass-thru operand, the second is the source vector, the third is the XLenVT
// index (either constant or non-constant), the fourth is the mask, the fifth
|
@llvm/pr-subscribers-llvm-selectiondag Author: hanbeom (ParkHanbum) Changes
Previously, each architecture attempted this optimization on its own, so this patch implements the common logic by adding the ISD::TRUNCATE_SSAT and ISD::TRUNCATE_USAT nodes. Full diff: https://github.com/llvm/llvm-project/pull/99418.diff 9 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index e6b10209b4767..0b36e5b40da73 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -804,6 +804,9 @@ enum NodeType {
/// TRUNCATE - Completely drop the high bits.
TRUNCATE,
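+ /// TRUNCATE_[SU]SAT - Truncate with saturation. DAGCombiner forms these from
+ /// a truncate of a value clamped to the destination type's signed (SSAT) or
+ /// unsigned (USAT) range.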
+ TRUNCATE_SSAT,
+ TRUNCATE_USAT,
/// [SU]INT_TO_FP - These operators convert integers (whose interpreted sign
/// depends on the first letter) to floating point.
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 133c9b113e51b..a5242694c9507 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -471,6 +471,8 @@ def sext : SDNode<"ISD::SIGN_EXTEND", SDTIntExtendOp>;
def zext : SDNode<"ISD::ZERO_EXTEND", SDTIntExtendOp>;
def anyext : SDNode<"ISD::ANY_EXTEND" , SDTIntExtendOp>;
def trunc : SDNode<"ISD::TRUNCATE" , SDTIntTruncOp>;
+def truncssat : SDNode<"ISD::TRUNCATE_SSAT", SDTIntTruncOp>;
+def truncusat : SDNode<"ISD::TRUNCATE_USAT", SDTIntTruncOp>;
def bitconvert : SDNode<"ISD::BITCAST" , SDTUnaryOp>;
def addrspacecast : SDNode<"ISD::ADDRSPACECAST", SDTUnaryOp>;
def freeze : SDNode<"ISD::FREEZE" , SDTFreeze>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 302ad128f4f53..967f313c9885e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -486,6 +486,8 @@ namespace {
SDValue visitSIGN_EXTEND_INREG(SDNode *N);
SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
SDValue visitTRUNCATE(SDNode *N);
+ SDValue visitTRUNCATE_SSAT(SDNode *N);
+ SDValue visitTRUNCATE_USAT(SDNode *N);
SDValue visitBITCAST(SDNode *N);
SDValue visitFREEZE(SDNode *N);
SDValue visitBUILD_PAIR(SDNode *N);
@@ -1907,6 +1909,8 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::ZERO_EXTEND_VECTOR_INREG:
case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
case ISD::TRUNCATE: return visitTRUNCATE(N);
+ case ISD::TRUNCATE_SSAT: return visitTRUNCATE_SSAT(N);
+ case ISD::TRUNCATE_USAT: return visitTRUNCATE_USAT(N);
case ISD::BITCAST: return visitBITCAST(N);
case ISD::BUILD_PAIR: return visitBUILD_PAIR(N);
case ISD::FADD: return visitFADD(N);
@@ -13154,7 +13158,8 @@ SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
unsigned CastOpcode = Cast->getOpcode();
assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
- CastOpcode == ISD::FP_ROUND) &&
+ CastOpcode == ISD::TRUNCATE_SSAT ||
+ CastOpcode == ISD::TRUNCATE_USAT || CastOpcode == ISD::FP_ROUND) &&
"Unexpected opcode for vector select narrowing/widening");
// We only do this transform before legal ops because the pattern may be
@@ -14867,6 +14872,119 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitTRUNCATE_USAT(SDNode *N) { // Combine hook for ISD::TRUNCATE_USAT: tries to form FP_TO_UINT_SAT.
+ EVT VT = N->getValueType(0);
+ SDValue N0 = N->getOperand(0);
+ SDValue FPInstr = N0.getOpcode() == ISD::SMAX ? N0.getOperand(0) : N0; // Look through one SMAX. NOTE(review): its bound (operand 1) is never checked.
+ if (FPInstr.getOpcode() == ISD::FP_TO_SINT || // NOTE(review): a signed conversion is also folded to an unsigned saturate; confirm for negative inputs.
+ FPInstr.getOpcode() == ISD::FP_TO_UINT) {
+ EVT FPVT = FPInstr.getOperand(0).getValueType(); // Type of the floating-point source.
+ if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT,
+ FPVT, VT))
+ return SDValue(); // Target hook rejected FP_TO_UINT_SAT for (FPVT, VT).
+ SDValue Sat = DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(FPInstr), VT,
+ FPInstr.getOperand(0),
+ DAG.getValueType(VT.getScalarType())); // Saturation width operand is the result's scalar type.
+ return Sat;
+ }
+
+ return SDValue(); // Operand is not an FP-to-integer conversion: no combine.
+}
+
+SDValue DAGCombiner::visitTRUNCATE_SSAT(SDNode *N) { return SDValue(); } // Stub: performs no combine and always returns an empty SDValue.
+
+/// Detect patterns of truncation with unsigned saturation:
+///
+/// 1. (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
+/// Return the source value x to be truncated or SDValue() if the pattern was
+/// not matched.
+///
+/// 2. (truncate (smin (smax (x, C1), C2)) to dest_type),
+/// where C1 >= 0 and C2 is the unsigned max of the destination type.
+///
+/// (truncate (smax (smin (x, C2), C1)) to dest_type)
+/// where C1 >= 0, C2 is the unsigned max of the destination type and C1 <= C2.
+///
+/// These two patterns are equivalent to:
+/// (truncate (umin (smax(x, C1), unsigned_max_of_dest_type)) to dest_type)
+/// So return the smax(x, C1) value to be truncated or SDValue() if the
+/// pattern was not matched. All limits must be constant splats.
+static SDValue detectUSatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
+ const SDLoc &DL) {
+ EVT InVT = In.getValueType();
+
+ // Saturation with truncation. We truncate from InVT to VT.
+ assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() &&
+ "Unexpected types for truncate operation");
+
+ // Match min/max and return limit value as a parameter.
+ auto MatchMinMax = [](SDValue V, unsigned Opcode, APInt &Limit) -> SDValue {
+ if (V.getOpcode() == Opcode &&
+ ISD::isConstantSplatVector(V.getOperand(1).getNode(), Limit))
+ return V.getOperand(0); // The non-constant side of the min/max.
+ return SDValue();
+ };
+
+ APInt C1, C2; // C1: smax limit, C2: umin/smin limit.
+ if (SDValue UMin = MatchMinMax(In, ISD::UMIN, C2))
+ // C2 should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according
+ // to the element size of the destination type.
+ if (C2.isMask(VT.getScalarSizeInBits()))
+ return UMin; // Despite the name, this is umin's operand x.
+
+ if (SDValue SMin = MatchMinMax(In, ISD::SMIN, C2))
+ if (MatchMinMax(SMin, ISD::SMAX, C1)) // Only C1 is needed from this match.
+ if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()))
+ return SMin; // SMin holds smin's operand, i.e. the smax(x, C1) node.
+
+ if (SDValue SMax = MatchMinMax(In, ISD::SMAX, C1))
+ if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, C2))
+ if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()) &&
+ C2.uge(C1))
+ return DAG.getNode(ISD::SMAX, DL, InVT, SMin, In.getOperand(1)); // Build smax(x, C1) with the smin stripped.
+
+ return SDValue(); // No unsigned-saturation pattern found.
+}
+
+/// Detect patterns of truncation with signed saturation:
+/// (truncate (smin ((smax (x, signed_min_of_dest_type)),
+/// signed_max_of_dest_type)) to dest_type)
+/// or:
+/// (truncate (smax ((smin (x, signed_max_of_dest_type)),
+/// signed_min_of_dest_type)) to dest_type).
+/// Both limits are sign-extended to the source width and must be constant splats.
+/// Return the source value to be truncated or SDValue() if the pattern was not
+/// matched.
+static SDValue detectSSatPattern(SDValue In, EVT VT) {
+ unsigned NumDstBits = VT.getScalarSizeInBits();
+ unsigned NumSrcBits = In.getScalarValueSizeInBits();
+ assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
+
+ auto MatchMinMax = [](SDValue V, unsigned Opcode,
+ const APInt &Limit) -> SDValue { // Matches Opcode(x, splat Limit) and yields x.
+ APInt C;
+ if (V.getOpcode() == Opcode &&
+ ISD::isConstantSplatVector(V.getOperand(1).getNode(), C) && C == Limit)
+ return V.getOperand(0);
+ return SDValue();
+ };
+
+ APInt SignedMax, SignedMin;
+ SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits); // Destination max at source width.
+ SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits); // Destination min at source width.
+ if (SDValue SMin = MatchMinMax(In, ISD::SMIN, SignedMax)) { // smin(smax(x, min), max)
+ if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, SignedMin)) {
+ return SMax; // Holds x, the operand of the inner smax.
+ }
+ }
+ if (SDValue SMax = MatchMinMax(In, ISD::SMAX, SignedMin)) { // smax(smin(x, max), min)
+ if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, SignedMax)) {
+ return SMin; // Holds x, the operand of the inner smin.
+ }
+ }
+ return SDValue(); // Neither clamp order matched.
+}
+
SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
@@ -14874,6 +14992,18 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
bool isLE = DAG.getDataLayout().isLittleEndian();
SDLoc DL(N);
+ if (!LegalOperations && N->getOpcode() == ISD::TRUNCATE) {
+ if (TLI.isOperationLegalOrCustom(ISD::TRUNCATE_SSAT, SrcVT)) {
+ if (SDValue SSatVal = detectSSatPattern(N0, VT))
+ return DAG.getNode(ISD::TRUNCATE_SSAT, DL, VT, SSatVal);
+ }
+
+ if (TLI.isOperationLegalOrCustom(ISD::TRUNCATE_USAT, SrcVT)) {
+ if (SDValue USatVal = detectUSatPattern(N0, VT, DAG, DL))
+ return DAG.getNode(ISD::TRUNCATE_USAT, DL, VT, USatVal);
+ }
+ }
+
// trunc(undef) = undef
if (N0.isUndef())
return DAG.getUNDEF(VT);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index cc8de3a217f82..d3ad6c8acf4f1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -380,6 +380,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::SIGN_EXTEND_VECTOR_INREG: return "sign_extend_vector_inreg";
case ISD::ZERO_EXTEND_VECTOR_INREG: return "zero_extend_vector_inreg";
case ISD::TRUNCATE: return "truncate";
+ case ISD::TRUNCATE_SSAT: return "truncate_ssat";
+ case ISD::TRUNCATE_USAT: return "truncate_usat";
case ISD::FP_ROUND: return "fp_round";
case ISD::STRICT_FP_ROUND: return "strict_fp_round";
case ISD::FP_EXTEND: return "fp_extend";
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index bf031c00a2449..3e855d5e450df 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -718,6 +718,10 @@ void TargetLoweringBase::initActions() {
// Absolute difference
setOperationAction({ISD::ABDS, ISD::ABDU}, VT, Expand);
+ // Saturated trunc
+ setOperationAction(ISD::TRUNCATE_SSAT, VT, Expand);
+ setOperationAction(ISD::TRUNCATE_USAT, VT, Expand);
+
// These default to Expand so they will be expanded to CTLZ/CTTZ by default.
setOperationAction({ISD::CTLZ_ZERO_UNDEF, ISD::CTTZ_ZERO_UNDEF}, VT,
Expand);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index df9b0ae1a632f..504bbaed1c8aa 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1274,6 +1274,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::AVGCEILU, VT, Legal);
setOperationAction(ISD::ABDS, VT, Legal);
setOperationAction(ISD::ABDU, VT, Legal);
+ setOperationAction(ISD::TRUNCATE_SSAT, VT, Legal);
+ setOperationAction(ISD::TRUNCATE_USAT, VT, Legal);
}
// Vector reductions
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index dd11f74882115..322219607407b 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -5343,9 +5343,13 @@ def VImm8000: PatLeaf<(AArch64mvni_msl (i32 127), (i32 264))>;
// trunc(umin(X, 255)) -> UQXTRN v8i8
def : Pat<(v8i8 (trunc (umin (v8i16 V128:$Vn), (v8i16 VImmFF)))),
(UQXTNv8i8 V128:$Vn)>;
+def : Pat<(v8i8 (truncusat (v8i16 V128:$Vn))),
+ (UQXTNv8i8 V128:$Vn)>;
// trunc(umin(X, 65535)) -> UQXTRN v4i16
def : Pat<(v4i16 (trunc (umin (v4i32 V128:$Vn), (v4i32 VImmFFFF)))),
(UQXTNv4i16 V128:$Vn)>;
+def : Pat<(v4i16 (truncusat (v4i32 V128:$Vn))),
+ (UQXTNv4i16 V128:$Vn)>;
// trunc(smin(smax(X, -128), 128)) -> SQXTRN
// with reversed min/max
def : Pat<(v8i8 (trunc (smin (smax (v8i16 V128:$Vn), (v8i16 VImm80)),
@@ -5354,6 +5358,8 @@ def : Pat<(v8i8 (trunc (smin (smax (v8i16 V128:$Vn), (v8i16 VImm80)),
def : Pat<(v8i8 (trunc (smax (smin (v8i16 V128:$Vn), (v8i16 VImm7F)),
(v8i16 VImm80)))),
(SQXTNv8i8 V128:$Vn)>;
+def : Pat<(v8i8 (truncssat (v8i16 V128:$Vn))),
+ (SQXTNv8i8 V128:$Vn)>;
// trunc(smin(smax(X, -32768), 32767)) -> SQXTRN
// with reversed min/max
def : Pat<(v4i16 (trunc (smin (smax (v4i32 V128:$Vn), (v4i32 VImm8000)),
@@ -5362,6 +5368,8 @@ def : Pat<(v4i16 (trunc (smin (smax (v4i32 V128:$Vn), (v4i32 VImm8000)),
def : Pat<(v4i16 (trunc (smax (smin (v4i32 V128:$Vn), (v4i32 VImm7FFF)),
(v4i32 VImm8000)))),
(SQXTNv4i16 V128:$Vn)>;
+def : Pat<(v4i16 (truncssat (v4i32 V128:$Vn))),
+ (SQXTNv4i16 V128:$Vn)>;
// concat_vectors(Vd, trunc(smin(smax Vm, -128), 127) ~> SQXTN2(Vd, Vn)
// with reversed min/max
@@ -5375,6 +5383,10 @@ def : Pat<(v16i8 (concat_vectors
(v8i8 (trunc (smax (smin (v8i16 V128:$Vn), (v8i16 VImm7F)),
(v8i16 VImm80)))))),
(SQXTNv16i8 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
+def : Pat<(v16i8 (concat_vectors
+ (v8i8 V64:$Vd),
+ (v8i8 (truncssat (v8i16 V128:$Vn))))),
+ (SQXTNv16i8 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
// concat_vectors(Vd, trunc(smin(smax Vm, -32768), 32767) ~> SQXTN2(Vd, Vn)
// with reversed min/max
@@ -5388,6 +5400,10 @@ def : Pat<(v8i16 (concat_vectors
(v4i16 (trunc (smax (smin (v4i32 V128:$Vn), (v4i32 VImm7FFF)),
(v4i32 VImm8000)))))),
(SQXTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
+def : Pat<(v8i16 (concat_vectors
+ (v4i16 V64:$Vd),
+ (v4i16 (truncssat (v4i32 V128:$Vn))))),
+ (SQXTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn)>;
// Select BSWAP vector instructions into REV instructions
def : Pat<(v4i16 (bswap (v4i16 V64:$Rn))),
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 953196a586b6e..3b54416d1b5b2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -853,7 +853,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
// Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL"
// nodes which truncate by one power of two at a time.
- setOperationAction(ISD::TRUNCATE, VT, Custom);
+ setOperationAction(
+ {ISD::TRUNCATE, ISD::TRUNCATE_SSAT, ISD::TRUNCATE_USAT}, VT, Custom);
// Custom-lower insert/extract operations to simplify patterns.
setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT,
@@ -1168,7 +1169,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SELECT, VT, Custom);
- setOperationAction(ISD::TRUNCATE, VT, Custom);
+ setOperationAction(
+ {ISD::TRUNCATE, ISD::TRUNCATE_SSAT, ISD::TRUNCATE_USAT}, VT,
+ Custom);
setOperationAction(ISD::BITCAST, VT, Custom);
@@ -1479,8 +1482,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN});
if ((Subtarget.hasStdExtZbs() && Subtarget.is64Bit()) ||
- Subtarget.hasStdExtV())
+ Subtarget.hasStdExtV()) {
setTargetDAGCombine(ISD::TRUNCATE);
+ setTargetDAGCombine(ISD::TRUNCATE_SSAT);
+ setTargetDAGCombine(ISD::TRUNCATE_USAT);
+ }
if (Subtarget.hasStdExtZbkb())
setTargetDAGCombine(ISD::BITREVERSE);
@@ -6092,7 +6098,7 @@ static bool hasMergeOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
- 130 &&
+ 132 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
@@ -6118,7 +6124,7 @@ static bool hasMaskOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
- 130 &&
+ 132 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
@@ -6389,6 +6395,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return DAG.getNode(RISCVISD::BREV8, DL, VT, BSwap);
}
case ISD::TRUNCATE:
+ case ISD::TRUNCATE_SSAT:
+ case ISD::TRUNCATE_USAT:
// Only custom-lower vector truncates
if (!Op.getSimpleValueType().isVector())
return Op;
@@ -8275,11 +8283,15 @@ SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op,
LLVMContext &Context = *DAG.getContext();
const ElementCount Count = ContainerVT.getVectorElementCount();
+ unsigned NewOpc = RISCVISD::TRUNCATE_VECTOR_VL;
+ if (Op.getOpcode() == ISD::TRUNCATE_SSAT)
+ NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_SSAT;
+ else if (Op.getOpcode() == ISD::TRUNCATE_USAT)
+ NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_USAT;
do {
SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2);
EVT ResultVT = EVT::getVectorVT(Context, SrcEltVT, Count);
- Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result,
- Mask, VL);
+ Result = DAG.getNode(NewOpc, DL, ResultVT, Result, Mask, VL);
} while (SrcEltVT != DstEltVT);
if (SrcVT.isFixedLengthVector())
@@ -16512,7 +16524,9 @@ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
// minimum value.
static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL);
+ assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL ||
+ N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_SSAT ||
+ N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_USAT);
MVT VT = N->getSimpleValueType(0);
@@ -16617,9 +16631,11 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
SDValue Val;
unsigned ClipOpc;
- if ((Val = DetectUSatPattern(Src)))
+
+ Val = N->getOperand(0);
+ if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_USAT)
ClipOpc = RISCVISD::VNCLIPU_VL;
- else if ((Val = DetectSSatPattern(Src)))
+ else if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_SSAT)
ClipOpc = RISCVISD::VNCLIP_VL;
else
return SDValue();
@@ -16857,6 +16873,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
}
return SDValue();
case RISCVISD::TRUNCATE_VECTOR_VL:
+ case RISCVISD::TRUNCATE_VECTOR_VL_SSAT:
+ case RISCVISD::TRUNCATE_VECTOR_VL_USAT:
if (SDValue V = combineTruncOfSraSext(N, DAG))
return V;
return combineTruncToVnclip(N, DAG, Subtarget);
@@ -20433,6 +20451,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(SPLAT_VECTOR_SPLIT_I64_VL)
NODE_NAME_CASE(READ_VLENB)
NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
+ NODE_NAME_CASE(TRUNCATE_VECTOR_VL_SSAT)
+ NODE_NAME_CASE(TRUNCATE_VECTOR_VL_USAT)
NODE_NAME_CASE(VSLIDEUP_VL)
NODE_NAME_CASE(VSLIDE1UP_VL)
NODE_NAME_CASE(VSLIDEDOWN_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 0b0ad9229f0b3..3d582fcdaf64b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -181,6 +181,8 @@ enum NodeType : unsigned {
// Truncates a RVV integer vector by one power-of-two. Carries both an extra
// mask and VL operand.
TRUNCATE_VECTOR_VL,
+ TRUNCATE_VECTOR_VL_SSAT,
+ TRUNCATE_VECTOR_VL_USAT,
// Matches the semantics of vslideup/vslidedown. The first operand is the
// pass-thru operand, the second is the source vector, the third is the XLenVT
// index (either constant or non-constant), the fourth is the mask, the fifth
|
@@ -804,6 +804,9 @@ enum NodeType { | |||
|
|||
/// TRUNCATE - Completely drop the high bits. | |||
TRUNCATE, | |||
/// TRUNCATE_[SU]SAT - Truncate for saturated operand | |||
TRUNCATE_SSAT, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is TRUNCATE_SSAT signed input to signed result?
Is TRUNCATE_USAT unsigned input to unsigned result?
I ask because X86's packuswb instruction takes a signed input and produces an unsigned result, so it's worth being clear here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is my opinion and needs confirmation.
It is true that SSAT and USAT are separated to indicate signedness, but 'truncate_[us]sat' means that the range of values for 'truncate' is in the range of values that don't have to care about the sign bit.
We also don't support unsigned type variables, right?
so, I think it is ok that 'truncate_ssat' instruction doesn't care about the type of target value, just return it to result type.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Variables don't have signed or unsigned type but operations can interpret their input as signed or unsigned.
I believe as you have defined it TRUNCATE_USAT will interpret the input as an unsigned value and produce an unsigned result in the destination type.
The x86 packuswb interprets the input as signed 16 bits and produces an unsigned 8 bit result. If the input is negative it will return 0. That is a different operation than either operation defined here. It's equivalent to smax with 0 followed by truncate_usat. I could imagine having that operation as a single node. Not suggesting for this patch.
I just want the semantics of the new opcodes documented.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AArch64 has a SQXTUN node too: https://docsmirror.github.io/A64/2023-09/sqxtun_advsimd.html
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So should we have 3 truncsat nodes? I mentioned this on #85903 but hadn't realised other targets had something similar to PACKUS
TRUNCATE_SSAT_S, // saturate signed input to signed result - truncate(smin(smax(x)))
TRUNCATE_SSAT_U, // saturate signed input to unsigned result - truncate(smin(smax(x,0)))
TRUNCATE_USAT_U, // saturate unsigned input to unsigned result - truncate(umin(x))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@RKSimon Should I add that ISD opcode? Is one not also necessary for unsigned input to signed result?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we have targets that support TRUNCATE_USAT_S then by all means add it.
return SDValue(); | ||
} | ||
|
||
SDValue DAGCombiner::visitTRUNCATE_SSAT(SDNode *N) { return SDValue(); } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we just not add the visit function instead of adding an empty function?
@@ -5343,9 +5343,13 @@ def VImm8000: PatLeaf<(AArch64mvni_msl (i32 127), (i32 264))>; | |||
// trunc(umin(X, 255)) -> UQXTRN v8i8 | |||
def : Pat<(v8i8 (trunc (umin (v8i16 V128:$Vn), (v8i16 VImmFF)))), | |||
(UQXTNv8i8 V128:$Vn)>; | |||
def : Pat<(v8i8 (truncusat (v8i16 V128:$Vn))), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can the pattern above be dropped?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not possible with the current logic. However, I think it would be possible with additional logic. Would you like to see it included in this patch?
SelectionDAG has 22 nodes:
t0: ch,glue = EntryToken
t2: v8f16,ch = CopyFromReg t0, Register:v8f16 %0
t16: v8i16 = fp_to_sint_sat t2, ValueType:ch:i16
t19: v8i16 = smin t16, t18
t22: v8i16 = smax t19, t21
t23: v8i8 = truncate t22
t4: v8f16,ch = CopyFromReg t0, Register:v8f16 %1
t24: v8i16 = fp_to_sint_sat t4, ValueType:ch:i16
t25: v8i16 = smin t24, t18
t26: v8i16 = smax t25, t21
t27: v8i8 = truncate t26
t13: v16i8 = concat_vectors t23, t27
t9: ch,glue = CopyToReg t0, Register:v16i8 $q0, t13
I think it needs support for fp_to_[su]int_sat.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What are the t18 and t21 nodes? It looks like you only pasted part of the DAG.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1173:SelectionDAG has 22 nodes:
1174- t0: ch,glue = EntryToken
1175- t2: v8f16,ch = CopyFromReg t0, Register:v8f16 %0
1176- t16: v8i16 = fp_to_sint_sat t2, ValueType:ch:i16
1177- t19: v8i16 = smin t16, t18
1178- t22: v8i16 = smax t19, t21
1179- t23: v8i8 = truncate t22
1180- t4: v8f16,ch = CopyFromReg t0, Register:v8f16 %1
1181- t24: v8i16 = fp_to_sint_sat t4, ValueType:ch:i16
1182- t25: v8i16 = smin t24, t18
1183- t26: v8i16 = smax t25, t21
1184- t27: v8i8 = truncate t26
1185- t13: v16i8 = concat_vectors t23, t27
1186- t9: ch,glue = CopyToReg t0, Register:v16i8 $q0, t13
1187- t18: v8i16 = BUILD_VECTOR Constant:i32<127>, Constant:i32<127>, Constant:i32<127>, Constant:i32<127>, Constant:i32<127>, Constant:i32<127>, Constant:i32<127>, Constant:i32<127>
1188- t21: v8i16 = BUILD_VECTOR Constant:i32<65408>, Constant:i32<65408>, Constant:i32<65408>, Constant:i32<65408>, Constant:i32<65408>, Constant:i32<65408>, Constant:i32<65408>, Constant:i32<65408>
1189- t10: ch = AArch64ISD::RET_GLUE t9, Register:v16i8 $q0, t9:1
It was at the bottom.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@topperc I tested this by implementing additional code to support fp_to_[su]int_sat. As a result, if LegalOperations is not checked, the associated pattern can be deleted.
Should we continue to check LegalOperations? I am not sure, because I'm not a veteran.
if ((Val = DetectUSatPattern(Src))) | ||
|
||
Val = N->getOperand(0); | ||
if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_USAT) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is TRUNCATE_VECTOR_VL_USAT the same as RISCVISD::VNCLIPU_VL?
ClipOpc = RISCVISD::VNCLIPU_VL; | ||
else if ((Val = DetectSSatPattern(Src))) | ||
else if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_SSAT) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is TRUNCATE_VECTOR_VL_SSAT the same as VNCLIP_VL?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nevermind. It's not the same because VNCLIP has a shift amount.
@@ -804,6 +804,9 @@ enum NodeType { | |||
|
|||
/// TRUNCATE - Completely drop the high bits. | |||
TRUNCATE, | |||
/// TRUNCATE_[SU]SAT - Truncate for saturated operand | |||
TRUNCATE_SSAT, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AArch64 has a SQXTUN node too: https://docsmirror.github.io/A64/2023-09/sqxtun_advsimd.html
setTargetDAGCombine(ISD::TRUNCATE); | ||
setTargetDAGCombine(ISD::TRUNCATE_SSAT); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect these 2 lines aren't needed.
I just submitted #100173 to remove the RISCVISD::VNCLIP* opcodes in favor of the TRUNCATE_VECTOR_VL_*SAT nodes added in this patch. This should allow us to remove the translating from TRUNCATE_VECTOR_VL_SAT to RISCVISD::VNCLIP. |
…USAT opcodes (#100173) These new opcodes drop the shift amount, rounding mode, and passthru. Making them exactly like TRUNCATE_VECTOR_VL. The shift amount, rounding mode, and passthru are added in isel patterns similar to how we translate TRUNCATE_VECTOR_VL to vnsrl with a shift of 0. This should simplify #99418 a little.
…USAT opcodes (#100173) Summary: These new opcodes drop the shift amount, rounding mode, and passthru. Making them exactly like TRUNCATE_VECTOR_VL. The shift amount, rounding mode, and passthru are added in isel patterns similar to how we translate TRUNCATE_VECTOR_VL to vnsrl with a shift of 0. This should simplify #99418 a little. Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60251265
✅ With the latest revision this PR passed the C/C++ code formatter. |
llvm/test/CodeGen/AArch64/qmovn.ll
Outdated
@@ -292,15 +292,15 @@ entry: | |||
|
|||
; Test the (concat_vectors (X), (trunc(umin(smax(Y, 0), 2^n))))) pattern. | |||
|
|||
; TODO: %min is a value between 0 and 255 and is within the unsigned range of i8. | |||
; So it is saturated truncate. we have an optimization opportunity. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm pretty sure this should be matching a sqxtun2, that was the intent. I believe because the lower limit is already clamped, that the upper smin is equivalent to umin.
https://godbolt.org/z/e7ne31TYb
You can see in that example that the midend turned smin into umin.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(right place :))
agree. I also added a comment with that intention and I'm thinking maybe adding code to DAGCombiner can solve it.
do you think I should include this in this patch as well?
SDValue DAGCombiner::visitTRUNCATE_USAT(SDNode *N) { | ||
EVT VT = N->getValueType(0); | ||
SDValue N0 = N->getOperand(0); | ||
SDValue FPInstr = N0.getOpcode() == ISD::SMAX ? N0.getOperand(0) : N0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you need to check the other operand of this SMAX?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, missed it. I'll fix it.
SDValue Sat = DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(FPInstr), VT, | ||
FPInstr.getOperand(0), | ||
DAG.getValueType(VT.getScalarType())); | ||
return Sat; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return DAG.getNode... no need for temporary variable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure how I can do it. Could you give me a little advice, please?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@topperc how about now? do you think I'm doing properly?
// fold satruated truncate | ||
if (SDValue SaturatedTR = foldToSaturated(N, VT, N0, SrcVT, DL, TLI, DAG)) { | ||
return SaturatedTR; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Drop curly braces
@@ -814,6 +814,13 @@ enum NodeType { | |||
|
|||
/// TRUNCATE - Completely drop the high bits. | |||
TRUNCATE, | |||
/// TRUNCATE_[SU]SAT - Truncate for saturated operand | |||
TRUNCATE_SSAT_S, // saturate signed input to signed result - | |||
// truncate(smin(smax(x))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please spell out the parameters for my eyes? [s|u][min|max] have two parameters.
(UQXTNv8i8 V128:$Vn)>; | ||
// trunc(umin(X, 65535)) -> UQXTRN v4i16 | ||
def : Pat<(v4i16 (trunc (umin (v4i32 V128:$Vn), (v4i32 VImmFFFF)))), | ||
def : Pat<(v4i16 (truncusat_u (v4i32 V128:$Vn))), | ||
(UQXTNv4i16 V128:$Vn)>; | ||
// trunc(smin(smax(X, -128), 128)) -> SQXTRN | ||
// with reversed min/max |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The "reversed min/max" lines can be removed now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I wasn't around yesterday, thanks for the updates. The code looks good to me now, if others agree.
@RKSimon @davemgreen Is it okay to reduce commits? |
Sure - as I mentioned above the commit summary at the top needs editing. Once that's done I'll approve. |
A truncate is considered saturated if no additional conversion is required between the target and return values. If the target is saturated when attempting to truncate from a vector, there is an opportunity to optimize it. Previously, each architecture had its own attempt at optimization, leading to redundant code. This patch implements common logic by introducing three new ISDs: `ISD::TRUNCATE_SSAT_S`: When the operand is a signed value and the range of values matches the range of signed values of the destination type. `ISD::TRUNCATE_SSAT_U`: When the operand is a signed value and the range of values matches the range of unsigned values of the destination type. `ISD::TRUNCATE_USAT_U`: When the operand is an unsigned value and the range of values matches the range of unsigned values of the destination type. These ISDs indicate a saturated truncate. Fixes llvm#85903
5cfe1a8
to
7db4887
Compare
Inspired by #99418 (which hopefully we can replace this code with at some point)
Can you update the description in github with the one in 8d81896 (or the others if you want to combine them)? LLVM uses a squash-and-merge approach to committing PR's, so they will be squashed into a single commit. |
Do you mean the message I wrote when I opened the PR? If so, it is done. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one final minor
@@ -1908,6 +1909,8 @@ SDValue DAGCombiner::visit(SDNode *N) { | |||
case ISD::ZERO_EXTEND_VECTOR_INREG: | |||
case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N); | |||
case ISD::TRUNCATE: return visitTRUNCATE(N); | |||
case ISD::TRUNCATE_SSAT_U: | |||
case ISD::TRUNCATE_USAT_U: return visitTRUNCATE_USAT(N); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be called visitTRUNCATE_SAT_U? Or should we just have a visitTRUNCATE_SAT call and handle the unsigned cases inside it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did we not mean to change this to just `case ISD::TRUNCATE_USAT_U: return visitTRUNCATE_USAT(N);`, without the TRUNCATE_SSAT_U case? (Sorry if I missed that.) Otherwise it will change TRUNCATE_SSAT_U(FP_TO_UINT(x)) to FP_TO_UINT_SAT(x), which will not clamp to the same bounds.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
visitTRUNCATE_USAT has a small task, so I don't see the need to separate SSAT and USAT.
Is it LLVM's way to reduce unnecessary separation and separate them if they are reasonably large?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, we can generalize it later if the need arises - but we need to confirm if we should be handling the TRUNCATE_SSAT_U case or not (do we have test coverage?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If that's the case, then we don't have a test for truncate_ssat_u. As @davemgreen commented, it is right to change to call visitTRUNCATE_USAT() in case TRUNCATE_USAT_U.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should I change it so that only `case ISD::TRUNCATE_USAT_U: return visitTRUNCATE_USAT(N);` remains?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so. Perhaps call it visitTRUNCATE_USAT_U too?
7db4887
to
a86ff34
Compare
Add support for saturated truncate with the following changes: - Add action to Legal for types v8i16, v4i32, and v2i64 - Implement `isTypeDesirableForOp` to check for truncate conversions - Add patterns for saturated truncate of supported types
Add support for saturated truncate by implementing the following changes: - Add `TRUNCATE_[SU]SAT_[SU]` to the Action target of `TRUNCATE` - Add `TRUNCATE_[SU]SAT_[SU]` to the TargetLowering target of `TRUNCATE` - Convert `TRUNCATE_SSAT_S` to `TRUNCATE_VECTOR_VL_SSAT` - Convert `TRUNCATE_[SU]SAT_U` to `TRUNCATE_VECTOR_VL_USAT`
a86ff34
to
9a40cec
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. LGTM
@ParkHanbum Are you happy for me to commit this? |
@RKSimon I feel very honored, Sir! please do it! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
unsigned NewOpc; | ||
if (Opc == ISD::TRUNCATE_SSAT_S) | ||
NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_SSAT; | ||
else if (Opc == ISD::TRUNCATE_SSAT_U || Opc == ISD::TRUNCATE_USAT_U) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How are we lowering 2 different opcodes to the same RISC-V instruction?
A truncate is considered saturated if no additional conversion is
required between the target and return values. If the target is
saturated when attempting to truncate from a vector, there is an
opportunity to optimize it.
Previously, each architecture had its own attempt at optimization,
leading to redundant code. This patch implements common logic by
introducing three new ISDs:
ISD::TRUNCATE_SSAT_S: When the operand is a signed value and the range of values matches the range of signed values of the destination type.
ISD::TRUNCATE_SSAT_U: When the operand is a signed value and the range of values matches the range of unsigned values of the destination type.
ISD::TRUNCATE_USAT_U: When the operand is an unsigned value and the range of values matches the range of unsigned values of the destination type.
These ISDs indicate a saturated truncate.
Fixes #85903