Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ISel: introduce vector ISD::LRINT, ISD::LLRINT; custom RISCV lowering #66924

Merged
merged 5 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions llvm/docs/LangRef.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15760,7 +15760,8 @@ Syntax:
"""""""

This is an overloaded intrinsic. You can use ``llvm.lrint`` on any
floating-point type. Not all targets support all types however.
floating-point type or vector of floating-point type. Not all targets
support all types however.

::

Expand Down Expand Up @@ -15804,7 +15805,8 @@ Syntax:
"""""""

This is an overloaded intrinsic. You can use ``llvm.llrint`` on any
floating-point type. Not all targets support all types however.
floating-point type or vector of floating-point type. Not all targets
support all types however.

::

Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/CodeGen/BasicTTIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1847,6 +1847,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
case Intrinsic::rint:
ISD = ISD::FRINT;
break;
case Intrinsic::lrint:
ISD = ISD::LRINT;
break;
case Intrinsic::llrint:
ISD = ISD::LLRINT;
break;
case Intrinsic::round:
ISD = ISD::FROUND;
break;
Expand Down
20 changes: 20 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ namespace {
SDValue visitUINT_TO_FP(SDNode *N);
SDValue visitFP_TO_SINT(SDNode *N);
SDValue visitFP_TO_UINT(SDNode *N);
SDValue visitXRINT(SDNode *N);
SDValue visitFP_ROUND(SDNode *N);
SDValue visitFP_EXTEND(SDNode *N);
SDValue visitFNEG(SDNode *N);
Expand Down Expand Up @@ -1911,6 +1912,7 @@ void DAGCombiner::Run(CombineLevel AtLevel) {
}

SDValue DAGCombiner::visit(SDNode *N) {
// clang-format off
switch (N->getOpcode()) {
default: break;
case ISD::TokenFactor: return visitTokenFactor(N);
Expand Down Expand Up @@ -2011,6 +2013,8 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::UINT_TO_FP: return visitUINT_TO_FP(N);
case ISD::FP_TO_SINT: return visitFP_TO_SINT(N);
case ISD::FP_TO_UINT: return visitFP_TO_UINT(N);
case ISD::LRINT:
case ISD::LLRINT: return visitXRINT(N);
case ISD::FP_ROUND: return visitFP_ROUND(N);
case ISD::FP_EXTEND: return visitFP_EXTEND(N);
case ISD::FNEG: return visitFNEG(N);
Expand Down Expand Up @@ -2065,6 +2069,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
#include "llvm/IR/VPIntrinsics.def"
return visitVPOp(N);
}
// clang-format on
return SDValue();
}

Expand Down Expand Up @@ -17480,6 +17485,21 @@ SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
return FoldIntToFPToInt(N, DAG);
}

/// Combine hook shared by ISD::LRINT and ISD::LLRINT nodes.
///
/// \param N the LRINT/LLRINT node being visited; operand 0 is the value
///          being rounded and result 0 carries the result type.
/// \returns a replacement value when one of the folds below applies, or an
///          empty SDValue when the node is left untouched.
SDValue DAGCombiner::visitXRINT(SDNode *N) {
  SDValue Src = N->getOperand(0);
  EVT ResultVT = N->getValueType(0);

  // fold (lrint|llrint undef) -> undef
  if (Src.isUndef())
    return DAG.getUNDEF(ResultVT);

  // Nothing else to do unless the operand is an FP constant (or a build
  // vector of FP constants).
  if (!DAG.isConstantFPBuildVectorOrConstantFP(Src))
    return SDValue();

  // fold (lrint|llrint c1fp) -> c1
  // The node is re-requested with the same opcode, type and operand;
  // presumably getNode performs the actual constant folding -- TODO confirm.
  return DAG.getNode(N->getOpcode(), SDLoc(N), ResultVT, Src);
}

SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
Expand Down
8 changes: 6 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2198,6 +2198,7 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
// to use the promoted float operand. Nodes that produce at least one
// promotion-requiring floating point result have their operands legalized as
// a part of PromoteFloatResult.
// clang-format off
switch (N->getOpcode()) {
default:
#ifndef NDEBUG
Expand All @@ -2209,7 +2210,9 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
case ISD::BITCAST: R = PromoteFloatOp_BITCAST(N, OpNo); break;
case ISD::FCOPYSIGN: R = PromoteFloatOp_FCOPYSIGN(N, OpNo); break;
case ISD::FP_TO_SINT:
case ISD::FP_TO_UINT: R = PromoteFloatOp_FP_TO_XINT(N, OpNo); break;
case ISD::FP_TO_UINT:
case ISD::LRINT:
case ISD::LLRINT: R = PromoteFloatOp_UnaryOp(N, OpNo); break;
case ISD::FP_TO_SINT_SAT:
case ISD::FP_TO_UINT_SAT:
R = PromoteFloatOp_FP_TO_XINT_SAT(N, OpNo); break;
Expand All @@ -2218,6 +2221,7 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
case ISD::SETCC: R = PromoteFloatOp_SETCC(N, OpNo); break;
case ISD::STORE: R = PromoteFloatOp_STORE(N, OpNo); break;
}
// clang-format on

if (R.getNode())
ReplaceValueWith(SDValue(N, 0), R);
Expand Down Expand Up @@ -2251,7 +2255,7 @@ SDValue DAGTypeLegalizer::PromoteFloatOp_FCOPYSIGN(SDNode *N, unsigned OpNo) {
}

// Convert the promoted float value to the desired integer type
SDValue DAGTypeLegalizer::PromoteFloatOp_FP_TO_XINT(SDNode *N, unsigned OpNo) {
// Rebuild a unary node whose floating-point operand needs promotion: the
// result keeps N's opcode and result type, while operand 0 is replaced by its
// promoted float value. OpNo is not consulted; operand 0 is always the one
// that gets promoted.
SDValue DAGTypeLegalizer::PromoteFloatOp_UnaryOp(SDNode *N, unsigned OpNo) {
  SDValue PromotedSrc = GetPromotedFloat(N->getOperand(0));
  EVT ResultVT = N->getValueType(0);
  return DAG.getNode(N->getOpcode(), SDLoc(N), ResultVT, PromotedSrc);
}
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteFloatOp_BITCAST(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_FCOPYSIGN(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_FP_EXTEND(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_FP_TO_XINT(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_UnaryOp(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_FP_TO_XINT_SAT(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_STORE(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_SELECT_CC(SDNode *N, unsigned OpNo);
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::FCEIL:
case ISD::FTRUNC:
case ISD::FRINT:
case ISD::LRINT:
case ISD::LLRINT:
case ISD::FNEARBYINT:
case ISD::FROUND:
case ISD::FROUNDEVEN:
Expand Down
16 changes: 15 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
case ISD::FP_TO_SINT:
case ISD::FP_TO_UINT:
case ISD::FRINT:
case ISD::LRINT:
case ISD::LLRINT:
case ISD::FROUND:
case ISD::FROUNDEVEN:
case ISD::FSIN:
Expand Down Expand Up @@ -681,6 +683,8 @@ bool DAGTypeLegalizer::ScalarizeVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::FP_TO_UINT:
case ISD::SINT_TO_FP:
case ISD::UINT_TO_FP:
case ISD::LRINT:
case ISD::LLRINT:
Res = ScalarizeVecOp_UnaryOp(N);
break;
case ISD::STRICT_SINT_TO_FP:
Expand Down Expand Up @@ -1097,6 +1101,8 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::VP_FP_TO_UINT:
case ISD::FRINT:
case ISD::VP_FRINT:
case ISD::LRINT:
case ISD::LLRINT:
case ISD::FROUND:
case ISD::VP_FROUND:
case ISD::FROUNDEVEN:
Expand Down Expand Up @@ -2974,6 +2980,8 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::ZERO_EXTEND:
case ISD::ANY_EXTEND:
case ISD::FTRUNC:
case ISD::LRINT:
case ISD::LLRINT:
Res = SplitVecOp_UnaryOp(N);
break;
case ISD::FLDEXP:
Expand Down Expand Up @@ -4209,6 +4217,8 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
case ISD::FLOG2:
case ISD::FNEARBYINT:
case ISD::FRINT:
case ISD::LRINT:
case ISD::LLRINT:
case ISD::FROUND:
case ISD::FROUNDEVEN:
case ISD::FSIN:
Expand Down Expand Up @@ -5958,7 +5968,11 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::STRICT_FSETCCS: Res = WidenVecOp_STRICT_FSETCC(N); break;
case ISD::VSELECT: Res = WidenVecOp_VSELECT(N); break;
case ISD::FLDEXP:
case ISD::FCOPYSIGN: Res = WidenVecOp_UnrollVectorOp(N); break;
case ISD::FCOPYSIGN:
case ISD::LRINT:
case ISD::LLRINT:
Res = WidenVecOp_UnrollVectorOp(N);
break;
case ISD::IS_FPCLASS: Res = WidenVecOp_IS_FPCLASS(N); break;

case ISD::ANY_EXTEND:
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5135,6 +5135,8 @@ bool SelectionDAG::isKnownNeverNaN(SDValue Op, bool SNaN, unsigned Depth) const
case ISD::FROUND:
case ISD::FROUNDEVEN:
case ISD::FRINT:
case ISD::LRINT:
case ISD::LLRINT:
case ISD::FNEARBYINT:
case ISD::FLDEXP: {
if (SNaN)
Expand Down
12 changes: 6 additions & 6 deletions llvm/lib/CodeGen/TargetLoweringBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -873,13 +873,13 @@ void TargetLoweringBase::initActions() {

// These operations default to expand for vector types.
if (VT.isVector())
setOperationAction({ISD::FCOPYSIGN, ISD::SIGN_EXTEND_INREG,
ISD::ANY_EXTEND_VECTOR_INREG,
ISD::SIGN_EXTEND_VECTOR_INREG,
ISD::ZERO_EXTEND_VECTOR_INREG, ISD::SPLAT_VECTOR},
VT, Expand);
setOperationAction(
{ISD::FCOPYSIGN, ISD::SIGN_EXTEND_INREG, ISD::ANY_EXTEND_VECTOR_INREG,
ISD::SIGN_EXTEND_VECTOR_INREG, ISD::ZERO_EXTEND_VECTOR_INREG,
ISD::SPLAT_VECTOR, ISD::LRINT, ISD::LLRINT},
VT, Expand);

// Constrained floating-point operations default to expand.
// Constrained floating-point operations default to expand.
#define DAG_INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN) \
setOperationAction(ISD::STRICT_##DAGN, VT, Expand);
#include "llvm/IR/ConstrainedOps.def"
Expand Down
22 changes: 20 additions & 2 deletions llvm/lib/IR/Verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5669,10 +5669,28 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
}
break;
}
case Intrinsic::lround:
case Intrinsic::llround:
case Intrinsic::lrint:
case Intrinsic::llrint: {
Type *ValTy = Call.getArgOperand(0)->getType();
Type *ResultTy = Call.getType();
Check(
ValTy->isFPOrFPVectorTy() && ResultTy->isIntOrIntVectorTy(),
"llvm.lrint, llvm.llrint: argument must be floating-point or vector "
"of floating-points, and result must be integer or vector of integers",
&Call);
Check(ValTy->isVectorTy() == ResultTy->isVectorTy(),
"llvm.lrint, llvm.llrint: argument and result disagree on vector use",
&Call);
if (ValTy->isVectorTy()) {
Check(cast<VectorType>(ValTy)->getElementCount() ==
cast<VectorType>(ResultTy)->getElementCount(),
"llvm.lrint, llvm.llrint: argument must be same length as result",
&Call);
}
break;
}
case Intrinsic::lround:
case Intrinsic::llround: {
Type *ValTy = Call.getArgOperand(0)->getType();
Type *ResultTy = Call.getType();
Check(!ValTy->isVectorTy() && !ResultTy->isVectorTy(),
Expand Down
30 changes: 29 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
VT, Custom);
setOperationAction({ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT}, VT,
Custom);

setOperationAction({ISD::LRINT, ISD::LLRINT}, VT, Custom);
setOperationAction(
{ISD::SADDSAT, ISD::UADDSAT, ISD::SSUBSAT, ISD::USUBSAT}, VT, Legal);

Expand Down Expand Up @@ -2950,6 +2950,31 @@ lowerFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
DAG.getTargetConstant(FRM, DL, Subtarget.getXLenVT()));
}

// Custom lowering for vector LRINT and LLRINT: the source is converted to the
// integer domain with a single RISCVISD::VFCVT_X_F_VL node.
//
// Fixed-length vectors are first moved into their scalable container type,
// converted there, and the result is moved back to the original fixed-length
// type. Scalable vectors are converted directly.
//
// NOTE(review): ContainerVT is derived from the (integer) result type VT, yet
// it is also the type passed to convertToScalableVector for the
// floating-point source -- confirm this is intended.
static SDValue lowerVectorXRINT(SDValue Op, SelectionDAG &DAG,
                                const RISCVSubtarget &Subtarget) {
  MVT VT = Op.getSimpleValueType();
  assert(VT.isVector() && "Unexpected type");

  SDLoc DL(Op);
  const bool IsFixedLength = VT.isFixedLengthVector();

  SDValue Src = Op.getOperand(0);
  MVT ContainerVT = VT;
  if (IsFixedLength) {
    ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
    Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
  }

  // Default mask and vector length for the type being lowered.
  auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
  SDValue Converted =
      DAG.getNode(RISCVISD::VFCVT_X_F_VL, DL, ContainerVT, Src, Mask, VL);

  // Only fixed-length inputs need to be unwrapped from the container type.
  if (IsFixedLength)
    return convertFromScalableVector(VT, Converted, DAG, Subtarget);

  return Converted;
}

static SDValue
getVSlidedown(SelectionDAG &DAG, const RISCVSubtarget &Subtarget,
const SDLoc &DL, EVT VT, SDValue Merge, SDValue Op,
Expand Down Expand Up @@ -5978,6 +6003,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::FROUND:
case ISD::FROUNDEVEN:
return lowerFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);
case ISD::LRINT:
case ISD::LLRINT:
return lowerVectorXRINT(Op, DAG, Subtarget);
case ISD::VECREDUCE_ADD:
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_SMAX:
Expand Down
27 changes: 27 additions & 0 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,31 @@ static const CostTblEntry VectorIntrinsicCostTable[]{
{Intrinsic::rint, MVT::nxv2f64, 7},
{Intrinsic::rint, MVT::nxv4f64, 7},
{Intrinsic::rint, MVT::nxv8f64, 7},
{Intrinsic::lrint, MVT::v2i32, 1},
{Intrinsic::lrint, MVT::v4i32, 1},
{Intrinsic::lrint, MVT::v8i32, 1},
{Intrinsic::lrint, MVT::v16i32, 1},
{Intrinsic::lrint, MVT::nxv1i32, 1},
{Intrinsic::lrint, MVT::nxv2i32, 1},
{Intrinsic::lrint, MVT::nxv4i32, 1},
{Intrinsic::lrint, MVT::nxv8i32, 1},
{Intrinsic::lrint, MVT::nxv16i32, 1},
{Intrinsic::lrint, MVT::v2i64, 1},
{Intrinsic::lrint, MVT::v4i64, 1},
{Intrinsic::lrint, MVT::v8i64, 1},
{Intrinsic::lrint, MVT::v16i64, 1},
{Intrinsic::lrint, MVT::nxv1i64, 1},
{Intrinsic::lrint, MVT::nxv2i64, 1},
{Intrinsic::lrint, MVT::nxv4i64, 1},
{Intrinsic::lrint, MVT::nxv8i64, 1},
{Intrinsic::llrint, MVT::v2i64, 1},
{Intrinsic::llrint, MVT::v4i64, 1},
{Intrinsic::llrint, MVT::v8i64, 1},
{Intrinsic::llrint, MVT::v16i64, 1},
{Intrinsic::llrint, MVT::nxv1i64, 1},
{Intrinsic::llrint, MVT::nxv2i64, 1},
{Intrinsic::llrint, MVT::nxv4i64, 1},
{Intrinsic::llrint, MVT::nxv8i64, 1},
{Intrinsic::nearbyint, MVT::v2f32, 9},
{Intrinsic::nearbyint, MVT::v4f32, 9},
{Intrinsic::nearbyint, MVT::v8f32, 9},
Expand Down Expand Up @@ -1051,6 +1076,8 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
case Intrinsic::floor:
case Intrinsic::trunc:
case Intrinsic::rint:
case Intrinsic::lrint:
case Intrinsic::llrint:
case Intrinsic::round:
case Intrinsic::roundeven: {
// These all use the same code.
Expand Down
Loading