-
Notifications
You must be signed in to change notification settings - Fork 12.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DAG] Support saturated truncate #99418
Changes from all commits
8d81896
7df00c5
5477448
9780ddc
345a104
9a40cec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -486,6 +486,7 @@ namespace { | |
SDValue visitSIGN_EXTEND_INREG(SDNode *N); | ||
SDValue visitEXTEND_VECTOR_INREG(SDNode *N); | ||
SDValue visitTRUNCATE(SDNode *N); | ||
SDValue visitTRUNCATE_USAT_U(SDNode *N); | ||
SDValue visitBITCAST(SDNode *N); | ||
SDValue visitFREEZE(SDNode *N); | ||
SDValue visitBUILD_PAIR(SDNode *N); | ||
|
@@ -1908,6 +1909,7 @@ SDValue DAGCombiner::visit(SDNode *N) { | |
case ISD::ZERO_EXTEND_VECTOR_INREG: | ||
case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N); | ||
case ISD::TRUNCATE: return visitTRUNCATE(N); | ||
case ISD::TRUNCATE_USAT_U: return visitTRUNCATE_USAT_U(N); | ||
case ISD::BITCAST: return visitBITCAST(N); | ||
case ISD::BUILD_PAIR: return visitBUILD_PAIR(N); | ||
case ISD::FADD: return visitFADD(N); | ||
|
@@ -13203,7 +13205,9 @@ SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) { | |
unsigned CastOpcode = Cast->getOpcode(); | ||
assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND || | ||
CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND || | ||
CastOpcode == ISD::FP_ROUND) && | ||
CastOpcode == ISD::TRUNCATE_SSAT_S || | ||
CastOpcode == ISD::TRUNCATE_SSAT_U || | ||
CastOpcode == ISD::TRUNCATE_USAT_U || CastOpcode == ISD::FP_ROUND) && | ||
"Unexpected opcode for vector select narrowing/widening"); | ||
|
||
// We only do this transform before legal ops because the pattern may be | ||
|
@@ -14915,6 +14919,132 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) { | |
return SDValue(); | ||
} | ||
|
||
SDValue DAGCombiner::visitTRUNCATE_USAT_U(SDNode *N) { | ||
EVT VT = N->getValueType(0); | ||
SDValue N0 = N->getOperand(0); | ||
|
||
std::function<SDValue(SDValue)> MatchFPTOINT = [&](SDValue Val) -> SDValue { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you need to check the other operand of this SMAX? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, missed it. I'll fix it. |
||
if (Val.getOpcode() == ISD::FP_TO_UINT) | ||
davemgreen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return Val; | ||
return SDValue(); | ||
}; | ||
|
||
SDValue FPInstr = MatchFPTOINT(N0); | ||
if (!FPInstr) | ||
return SDValue(); | ||
|
||
EVT FPVT = FPInstr.getOperand(0).getValueType(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. return DAG.getNode... no need for temporary variable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure how I can do it. could you give me little advise please? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @topperc how about now? do you think I'm doing properly? |
||
if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT, | ||
FPVT, VT)) | ||
return SDValue(); | ||
return DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(FPInstr), VT, | ||
FPInstr.getOperand(0), | ||
DAG.getValueType(VT.getScalarType())); | ||
davemgreen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
/// Detect patterns of truncation with unsigned saturation: | ||
/// | ||
/// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type). | ||
/// Return the source value x to be truncated or SDValue() if the pattern was | ||
/// not matched. | ||
/// | ||
static SDValue detectUSatUPattern(SDValue In, EVT VT) { | ||
unsigned NumDstBits = VT.getScalarSizeInBits(); | ||
unsigned NumSrcBits = In.getScalarValueSizeInBits(); | ||
// Saturation with truncation. We truncate from InVT to VT. | ||
assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation"); | ||
|
||
davemgreen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
SDValue Min; | ||
davemgreen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
APInt UnsignedMax = APInt::getMaxValue(NumDstBits).zext(NumSrcBits); | ||
if (sd_match(In, m_UMin(m_Value(Min), m_SpecificInt(UnsignedMax)))) | ||
return Min; | ||
|
||
return SDValue(); | ||
} | ||
|
||
/// Detect patterns of truncation with signed saturation: | ||
/// (truncate (smin (smax (x, signed_min_of_dest_type), | ||
/// signed_max_of_dest_type)) to dest_type) | ||
/// or: | ||
/// (truncate (smax (smin (x, signed_max_of_dest_type), | ||
/// signed_min_of_dest_type)) to dest_type). | ||
/// | ||
RKSimon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/// Return the source value to be truncated or SDValue() if the pattern was not | ||
/// matched. | ||
static SDValue detectSSatSPattern(SDValue In, EVT VT) { | ||
unsigned NumDstBits = VT.getScalarSizeInBits(); | ||
unsigned NumSrcBits = In.getScalarValueSizeInBits(); | ||
// Saturation with truncation. We truncate from InVT to VT. | ||
assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation"); | ||
|
||
SDValue Val; | ||
APInt SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits); | ||
APInt SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits); | ||
|
||
if (sd_match(In, m_SMin(m_SMax(m_Value(Val), m_SpecificInt(SignedMin)), | ||
m_SpecificInt(SignedMax)))) | ||
return Val; | ||
|
||
if (sd_match(In, m_SMax(m_SMin(m_Value(Val), m_SpecificInt(SignedMax)), | ||
davemgreen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
m_SpecificInt(SignedMin)))) | ||
return Val; | ||
davemgreen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return SDValue(); | ||
} | ||
|
||
/// Detect patterns of truncation with unsigned saturation: | ||
static SDValue detectSSatUPattern(SDValue In, EVT VT, SelectionDAG &DAG, | ||
const SDLoc &DL) { | ||
unsigned NumDstBits = VT.getScalarSizeInBits(); | ||
unsigned NumSrcBits = In.getScalarValueSizeInBits(); | ||
// Saturation with truncation. We truncate from InVT to VT. | ||
assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation"); | ||
|
||
SDValue Val; | ||
APInt UnsignedMax = APInt::getMaxValue(NumDstBits).zext(NumSrcBits); | ||
// Min == 0, Max is unsigned max of destination type. | ||
if (sd_match(In, m_SMax(m_SMin(m_Value(Val), m_SpecificInt(UnsignedMax)), | ||
m_Zero()))) | ||
return Val; | ||
|
||
if (sd_match(In, m_SMin(m_SMax(m_Value(Val), m_Zero()), | ||
m_SpecificInt(UnsignedMax)))) | ||
RKSimon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return Val; | ||
|
||
if (sd_match(In, m_UMin(m_SMax(m_Value(Val), m_Zero()), | ||
m_SpecificInt(UnsignedMax)))) | ||
return Val; | ||
|
||
return SDValue(); | ||
} | ||
|
||
static SDValue foldToSaturated(SDNode *N, EVT &VT, SDValue &Src, EVT &SrcVT, | ||
SDLoc &DL, const TargetLowering &TLI, | ||
SelectionDAG &DAG) { | ||
auto AllowedTruncateSat = [&](unsigned Opc, EVT SrcVT, EVT VT) -> bool { | ||
return (TLI.isOperationLegalOrCustom(Opc, SrcVT) && | ||
TLI.isTypeDesirableForOp(Opc, VT)); | ||
}; | ||
|
||
if (Src.getOpcode() == ISD::SMIN || Src.getOpcode() == ISD::SMAX) { | ||
if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_S, SrcVT, VT)) | ||
if (SDValue SSatVal = detectSSatSPattern(Src, VT)) | ||
return DAG.getNode(ISD::TRUNCATE_SSAT_S, DL, VT, SSatVal); | ||
if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_U, SrcVT, VT)) | ||
if (SDValue SSatVal = detectSSatUPattern(Src, VT, DAG, DL)) | ||
return DAG.getNode(ISD::TRUNCATE_SSAT_U, DL, VT, SSatVal); | ||
} else if (Src.getOpcode() == ISD::UMIN) { | ||
if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_U, SrcVT, VT)) | ||
if (SDValue SSatVal = detectSSatUPattern(Src, VT, DAG, DL)) | ||
return DAG.getNode(ISD::TRUNCATE_SSAT_U, DL, VT, SSatVal); | ||
davemgreen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if (AllowedTruncateSat(ISD::TRUNCATE_USAT_U, SrcVT, VT)) | ||
if (SDValue USatVal = detectUSatUPattern(Src, VT)) | ||
return DAG.getNode(ISD::TRUNCATE_USAT_U, DL, VT, USatVal); | ||
} | ||
|
||
return SDValue(); | ||
} | ||
|
||
SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { | ||
davemgreen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
SDValue N0 = N->getOperand(0); | ||
EVT VT = N->getValueType(0); | ||
|
@@ -14930,6 +15060,10 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { | |
if (N0.getOpcode() == ISD::TRUNCATE) | ||
return DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0)); | ||
|
||
// fold saturated truncate | ||
if (SDValue SaturatedTR = foldToSaturated(N, VT, N0, SrcVT, DL, TLI, DAG)) | ||
RKSimon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return SaturatedTR; | ||
|
||
// fold (truncate c1) -> c1 | ||
if (SDValue C = DAG.FoldConstantArithmetic(ISD::TRUNCATE, DL, VT, {N0})) | ||
return C; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be called visitTRUNCATE_SAT_U? Or should we just have a visitTRUNCATE_SAT call and handle the unsigned cases inside it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did we not mean to change this to just
case ISD::TRUNCATE_USAT_U: return visitTRUNCATE_USAT(N);
, without the TRUNCATE_SSAT_U case? (Sorry if I missed that). Otherwise it will changeTRUNCATE_SSAT_U(FP_TO_UINT(x))
toFP_TO_UINT_SAT(x)
, which will not clamp to the same bounds.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
visitTRUNCATE_USAT has a small task, so I don't see the need to separate SSAT and USAT.
Is it LLVM's way to reduce unnecessary separation and separate them if they are reasonably large?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, we can generalize it later if the need arises - but we need to confirm if we should be handling the TRUNCATE_SSAT_U case or not (do we have test coverage?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If that's the case, then we don't have a test for truncate_ssat_u. As @davemgreen commented, it is right to change to call visitTRUNCATE_USAT() in case TRUNCATE_USAT_U.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should I change it so that only
case ISD::TRUNCATE_USAT_U: return visitTRUNCATE_USAT(N);
remains?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so. Perhaps call it visitTRUNCATE_USAT_U too?