Skip to content

Commit 1fd8713

Browse files
committed
[spirv] Clean up usage of FunctionType.
1 parent da0e18b commit 1fd8713

9 files changed

+43
-79
lines changed

tools/clang/include/clang/SPIRV/SPIRVContext.h

+1-4
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,6 @@ struct FunctionTypeMapInfo {
129129
/// context is deleted. Therefore, this context should outlive the usages of the
130130
/// the SPIR-V entities allocated in memory.
131131
class SpirvContext {
132-
friend class SpirvBuilder;
133-
friend class EmitTypeHandler;
134-
135132
public:
136133
SpirvContext();
137134
~SpirvContext() = default;
@@ -197,7 +194,7 @@ class SpirvContext {
197194
FunctionType *getFunctionType(const SpirvType *ret,
198195
llvm::ArrayRef<const SpirvType *> param);
199196
HybridFunctionType *getFunctionType(QualType ret,
200-
llvm::ArrayRef<const SpirvType *> param);
197+
llvm::ArrayRef<QualType> param);
201198

202199
const StructType *getByteAddressBufferType(bool isWritable);
203200
const StructType *getACSBufferCounterType();

tools/clang/include/clang/SPIRV/SpirvBuilder.h

-2
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ class SpirvBuilder {
6363
/// type in the current function and returns its pointer.
6464
SpirvFunctionParameter *addFnParam(QualType ptrType, SourceLocation,
6565
llvm::StringRef name = "");
66-
SpirvFunctionParameter *addFnParam(const SpirvType *ptrType, SourceLocation,
67-
llvm::StringRef name = "");
6866

6967
/// \brief Creates a local variable of the given type in the current
7068
/// function and returns it.

tools/clang/include/clang/SPIRV/SpirvType.h

+11-10
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,7 @@ class StructType : public SpirvType {
337337
StructInterfaceType interfaceType;
338338
};
339339

340+
/// Represents a SPIR-V pointer type.
340341
class SpirvPointerType : public SpirvType {
341342
public:
342343
SpirvPointerType(const SpirvType *pointee, spv::StorageClass sc)
@@ -356,11 +357,11 @@ class SpirvPointerType : public SpirvType {
356357
spv::StorageClass storageClass;
357358
};
358359

360+
/// Represents a SPIR-V function type. None of the parameters nor the return
361+
/// type is allowed to be a hybrid type.
359362
class FunctionType : public SpirvType {
360363
public:
361-
FunctionType(const SpirvType *ret, llvm::ArrayRef<const SpirvType *> param)
362-
: SpirvType(TK_Function), returnType(ret),
363-
paramTypes(param.begin(), param.end()) {}
364+
FunctionType(const SpirvType *ret, llvm::ArrayRef<const SpirvType *> param);
364365

365366
static bool classof(const SpirvType *t) {
366367
return t->getKind() == TK_Function;
@@ -484,24 +485,24 @@ class HybridSampledImageType : public HybridType {
484485
// This class can be extended to also accept QualType vector as param types.
485486
class HybridFunctionType : public HybridType {
486487
public:
487-
HybridFunctionType(QualType ret, llvm::ArrayRef<const SpirvType *> param)
488-
: HybridType(TK_HybridFunction), astReturnType(ret),
488+
HybridFunctionType(QualType ret, llvm::ArrayRef<QualType> param)
489+
: HybridType(TK_HybridFunction), returnType(ret),
489490
paramTypes(param.begin(), param.end()) {}
490491

491492
static bool classof(const SpirvType *t) {
492493
return t->getKind() == TK_HybridFunction;
493494
}
494495

495496
bool operator==(const HybridFunctionType &that) const {
496-
return astReturnType == that.astReturnType && paramTypes == that.paramTypes;
497+
return returnType == that.returnType && paramTypes == that.paramTypes;
497498
}
498499

499-
QualType getAstReturnType() const { return astReturnType; }
500-
llvm::ArrayRef<const SpirvType *> getParamTypes() const { return paramTypes; }
500+
QualType getReturnType() const { return returnType; }
501+
llvm::ArrayRef<QualType> getParamTypes() const { return paramTypes; }
501502

502503
private:
503-
QualType astReturnType;
504-
llvm::SmallVector<const SpirvType *, 8> paramTypes;
504+
QualType returnType;
505+
llvm::SmallVector<QualType, 8> paramTypes;
505506
};
506507

507508
} // end namespace spirv

tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

+1-3
Original file line numberDiff line numberDiff line change
@@ -535,11 +535,9 @@ SpirvInstruction *DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl) {
535535
SpirvFunctionParameter *
536536
DeclResultIdMapper::createFnParam(const ParmVarDecl *param) {
537537
const auto type = getTypeOrFnRetType(param);
538-
const auto *ptrType =
539-
spvContext.getPointerType(type, spv::StorageClass::Function);
540538
const auto loc = param->getLocation();
541539
SpirvFunctionParameter *fnParamInstr =
542-
spvBuilder.addFnParam(ptrType, loc, param->getName());
540+
spvBuilder.addFnParam(type, loc, param->getName());
543541

544542
bool isAlias = false;
545543
(void)getTypeAndCreateCounterForPotentialAliasVar(param, &isAlias);

tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

+15-37
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ bool LowerTypeVisitor::visitInstruction(SpirvInstruction *instr) {
8383
}
8484
break;
8585
}
86-
// Variables must have a pointer type.
86+
// Variables and function parameters must have a pointer type.
87+
case spv::Op::OpFunctionParameter:
8788
case spv::Op::OpVariable: {
8889
const SpirvType *pointerType =
8990
spvContext.getPointerType(resultType, instr->getStorageClass());
@@ -141,18 +142,17 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
141142
return spvContext.getSampledImageType(cast<ImageType>(imageSpirvType));
142143
} else if (const auto *hybridFn = dyn_cast<HybridFunctionType>(type)) {
143144
// Lower the return type.
144-
const QualType astReturnType = hybridFn->getAstReturnType();
145+
const QualType astReturnType = hybridFn->getReturnType();
145146
const SpirvType *spirvReturnType =
146147
lowerType(astReturnType, rule, /*isRowMajor*/ llvm::None, loc);
147148

148-
// Go over all params. If any of them is hybrid, lower it.
149+
// Go over all params and lower them.
149150
std::vector<const SpirvType *> paramTypes;
150-
for (auto *paramType : hybridFn->getParamTypes()) {
151-
if (const auto *hybridParam = dyn_cast<HybridType>(paramType)) {
152-
paramTypes.push_back(lowerType(hybridParam, rule, loc));
153-
} else {
154-
paramTypes.push_back(paramType);
155-
}
151+
for (auto paramType : hybridFn->getParamTypes()) {
152+
const auto *spirvParamType =
153+
lowerType(paramType, rule, /*isRowMajor*/ llvm::None, loc);
154+
paramTypes.push_back(spvContext.getPointerType(
155+
spirvParamType, spv::StorageClass::Function));
156156
}
157157

158158
return spvContext.getFunctionType(spirvReturnType, paramTypes);
@@ -169,9 +169,12 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
169169
// sampledType in image types can only be numberical type.
170170
// Sampler types cannot be further lowered.
171171
// SampledImage types cannot be further lowered.
172+
// FunctionType is not allowed to contain hybrid parameters or return type.
173+
// StructType is not allowed to contain any hybrid types.
172174
else if (isa<VoidType>(type) || isa<ScalarType>(type) ||
173175
isa<MatrixType>(type) || isa<ImageType>(type) ||
174-
isa<SamplerType>(type) || isa<SampledImageType>(type)) {
176+
isa<SamplerType>(type) || isa<SampledImageType>(type) ||
177+
isa<FunctionType>(type) || isa<StructType>(type)) {
175178
return type;
176179
}
177180
// Vectors could contain a hybrid type
@@ -204,11 +207,6 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
204207
return raType;
205208
return spvContext.getRuntimeArrayType(loweredElemType, raType->getStride());
206209
}
207-
// Struct types could contain a hybrid type
208-
else if (const auto *structType = dyn_cast<StructType>(type)) {
209-
// Struct types can not contain hybrid types.
210-
return structType;
211-
}
212210
// Pointer types could point to a hybrid type.
213211
else if (const auto *ptrType = dyn_cast<SpirvPointerType>(type)) {
214212
const auto *loweredPointee =
@@ -220,26 +218,6 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
220218
return spvContext.getPointerType(loweredPointee,
221219
ptrType->getStorageClass());
222220
}
223-
// Function types may have a parameter or return type that is hybrid.
224-
else if (const auto *fnType = dyn_cast<FunctionType>(type)) {
225-
const auto *loweredRetType = lowerType(fnType->getReturnType(), rule, loc);
226-
bool wasLowered = fnType->getReturnType() != loweredRetType;
227-
llvm::SmallVector<const SpirvType *, 4> loweredParams;
228-
const auto &paramTypes = fnType->getParamTypes();
229-
for (auto *paramType : paramTypes) {
230-
const auto *loweredParamType = lowerType(paramType, rule, loc);
231-
loweredParams.push_back(loweredParamType);
232-
if (loweredParamType != paramType) {
233-
wasLowered = true;
234-
}
235-
}
236-
// If the function type didn't include any hybrid types, return itself.
237-
if (!wasLowered) {
238-
return fnType;
239-
}
240-
241-
return spvContext.getFunctionType(loweredRetType, loweredParams);
242-
}
243221

244222
llvm_unreachable("lowering of hybrid type not implemented");
245223
}
@@ -319,10 +297,10 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
319297
// LowerTypeVisitor is invoked. We should error out if we encounter a
320298
// literal type.
321299
case BuiltinType::LitInt:
322-
//emitError("found literal int type when lowering types", srcLoc);
300+
// emitError("found literal int type when lowering types", srcLoc);
323301
return spvContext.getUIntType(64);
324302
case BuiltinType::LitFloat: {
325-
//emitError("found literal float type when lowering types", srcLoc);
303+
// emitError("found literal float type when lowering types", srcLoc);
326304
return spvContext.getFloatType(64);
327305

328306
default:

tools/clang/lib/SPIRV/SPIRVContext.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,7 @@ SpirvContext::getFunctionType(const SpirvType *ret,
240240
}
241241

242242
HybridFunctionType *
243-
SpirvContext::getFunctionType(QualType ret,
244-
llvm::ArrayRef<const SpirvType *> param) {
243+
SpirvContext::getFunctionType(QualType ret, llvm::ArrayRef<QualType> param) {
245244
return new (this) HybridFunctionType(ret, param);
246245
}
247246

tools/clang/lib/SPIRV/SPIRVEmitter.cpp

+4-9
Original file line numberDiff line numberDiff line change
@@ -964,7 +964,7 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
964964
declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(decl);
965965

966966
// Construct the function signature.
967-
llvm::SmallVector<const SpirvType *, 4> paramTypes;
967+
llvm::SmallVector<QualType, 4> paramTypes;
968968

969969
bool isNonStaticMemberFn = false;
970970
if (const auto *memberFn = dyn_cast<CXXMethodDecl>(decl)) {
@@ -975,18 +975,14 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
975975
// object on which we are invoking this method.
976976
const QualType valueType =
977977
memberFn->getThisType(astContext)->getPointeeType();
978-
const SpirvType *ptrType =
979-
spvContext.getPointerType(valueType, spv::StorageClass::Function);
980-
paramTypes.push_back(ptrType);
978+
paramTypes.push_back(valueType);
981979
}
982980
}
983981

984982
for (const auto *param : decl->params()) {
985983
const QualType valueType =
986984
declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(param);
987-
const SpirvType *ptrType =
988-
spvContext.getPointerType(valueType, spv::StorageClass::Function);
989-
paramTypes.push_back(ptrType);
985+
paramTypes.push_back(valueType);
990986
}
991987

992988
auto *funcType = spvContext.getFunctionType(retType, paramTypes);
@@ -8972,8 +8968,7 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
89728968
uint32_t outputArraySize = 0;
89738969

89748970
// Construct the wrapper function signature.
8975-
const SpirvType *voidType = spvContext.getVoidType();
8976-
FunctionType *funcType = spvContext.getFunctionType(voidType, {});
8971+
auto *funcType = spvContext.getFunctionType(astContext.VoidTy, {});
89778972

89788973
// The wrapper entry function surely does not have pre-assigned <result-id>
89798974
// for it like other functions that got added to the work queue following

tools/clang/lib/SPIRV/SpirvBuilder.cpp

-12
Original file line numberDiff line numberDiff line change
@@ -56,18 +56,6 @@ SpirvFunctionParameter *SpirvBuilder::addFnParam(QualType ptrType,
5656
return param;
5757
}
5858

59-
SpirvFunctionParameter *SpirvBuilder::addFnParam(const SpirvType *ptrType,
60-
SourceLocation loc,
61-
llvm::StringRef name) {
62-
assert(function && "found detached parameter");
63-
auto *param =
64-
new (context) SpirvFunctionParameter(/*QualType*/ {}, /*id*/ 0, loc);
65-
param->setResultType(ptrType);
66-
param->setDebugName(name);
67-
function->addParameter(param);
68-
return param;
69-
}
70-
7159
SpirvVariable *SpirvBuilder::addFnVar(QualType valueType, SourceLocation loc,
7260
llvm::StringRef name,
7361
SpirvInstruction *init) {

tools/clang/lib/SPIRV/SpirvType.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -243,5 +243,15 @@ bool HybridStructType::operator==(const HybridStructType &that) const {
243243
readOnly == that.readOnly && interfaceType == that.interfaceType;
244244
}
245245

246+
FunctionType::FunctionType(const SpirvType *ret,
247+
llvm::ArrayRef<const SpirvType *> param)
248+
: SpirvType(TK_Function), returnType(ret),
249+
paramTypes(param.begin(), param.end()) {
250+
// Make sure
251+
assert(!isa<HybridType>(ret));
252+
for (auto *paramType : param)
253+
assert(!isa<HybridType>(param));
254+
}
255+
246256
} // namespace spirv
247257
} // namespace clang

0 commit comments

Comments
 (0)