@@ -83,7 +83,8 @@ bool LowerTypeVisitor::visitInstruction(SpirvInstruction *instr) {
83
83
}
84
84
break ;
85
85
}
86
- // Variables must have a pointer type.
86
+ // Variables and function parameters must have a pointer type.
87
+ case spv::Op::OpFunctionParameter:
87
88
case spv::Op::OpVariable: {
88
89
const SpirvType *pointerType =
89
90
spvContext.getPointerType (resultType, instr->getStorageClass ());
@@ -141,18 +142,17 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
141
142
return spvContext.getSampledImageType (cast<ImageType>(imageSpirvType));
142
143
} else if (const auto *hybridFn = dyn_cast<HybridFunctionType>(type)) {
143
144
// Lower the return type.
144
- const QualType astReturnType = hybridFn->getAstReturnType ();
145
+ const QualType astReturnType = hybridFn->getReturnType ();
145
146
const SpirvType *spirvReturnType =
146
147
lowerType (astReturnType, rule, /* isRowMajor*/ llvm::None, loc);
147
148
148
- // Go over all params. If any of them is hybrid, lower it .
149
+ // Go over all params and lower them .
149
150
std::vector<const SpirvType *> paramTypes;
150
- for (auto *paramType : hybridFn->getParamTypes ()) {
151
- if (const auto *hybridParam = dyn_cast<HybridType>(paramType)) {
152
- paramTypes.push_back (lowerType (hybridParam, rule, loc));
153
- } else {
154
- paramTypes.push_back (paramType);
155
- }
151
+ for (auto paramType : hybridFn->getParamTypes ()) {
152
+ const auto *spirvParamType =
153
+ lowerType (paramType, rule, /* isRowMajor*/ llvm::None, loc);
154
+ paramTypes.push_back (spvContext.getPointerType (
155
+ spirvParamType, spv::StorageClass::Function));
156
156
}
157
157
158
158
return spvContext.getFunctionType (spirvReturnType, paramTypes);
@@ -169,9 +169,12 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
169
169
// sampledType in image types can only be numberical type.
170
170
// Sampler types cannot be further lowered.
171
171
// SampledImage types cannot be further lowered.
172
+ // FunctionType is not allowed to contain hybrid parameters or return type.
173
+ // StructType is not allowed to contain any hybrid types.
172
174
else if (isa<VoidType>(type) || isa<ScalarType>(type) ||
173
175
isa<MatrixType>(type) || isa<ImageType>(type) ||
174
- isa<SamplerType>(type) || isa<SampledImageType>(type)) {
176
+ isa<SamplerType>(type) || isa<SampledImageType>(type) ||
177
+ isa<FunctionType>(type) || isa<StructType>(type)) {
175
178
return type;
176
179
}
177
180
// Vectors could contain a hybrid type
@@ -204,11 +207,6 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
204
207
return raType;
205
208
return spvContext.getRuntimeArrayType (loweredElemType, raType->getStride ());
206
209
}
207
- // Struct types could contain a hybrid type
208
- else if (const auto *structType = dyn_cast<StructType>(type)) {
209
- // Struct types can not contain hybrid types.
210
- return structType;
211
- }
212
210
// Pointer types could point to a hybrid type.
213
211
else if (const auto *ptrType = dyn_cast<SpirvPointerType>(type)) {
214
212
const auto *loweredPointee =
@@ -220,26 +218,6 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
220
218
return spvContext.getPointerType (loweredPointee,
221
219
ptrType->getStorageClass ());
222
220
}
223
- // Function types may have a parameter or return type that is hybrid.
224
- else if (const auto *fnType = dyn_cast<FunctionType>(type)) {
225
- const auto *loweredRetType = lowerType (fnType->getReturnType (), rule, loc);
226
- bool wasLowered = fnType->getReturnType () != loweredRetType;
227
- llvm::SmallVector<const SpirvType *, 4 > loweredParams;
228
- const auto ¶mTypes = fnType->getParamTypes ();
229
- for (auto *paramType : paramTypes) {
230
- const auto *loweredParamType = lowerType (paramType, rule, loc);
231
- loweredParams.push_back (loweredParamType);
232
- if (loweredParamType != paramType) {
233
- wasLowered = true ;
234
- }
235
- }
236
- // If the function type didn't include any hybrid types, return itself.
237
- if (!wasLowered) {
238
- return fnType;
239
- }
240
-
241
- return spvContext.getFunctionType (loweredRetType, loweredParams);
242
- }
243
221
244
222
llvm_unreachable (" lowering of hybrid type not implemented" );
245
223
}
@@ -319,10 +297,10 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
319
297
// LowerTypeVisitor is invoked. We should error out if we encounter a
320
298
// literal type.
321
299
case BuiltinType::LitInt:
322
- // emitError("found literal int type when lowering types", srcLoc);
300
+ // emitError("found literal int type when lowering types", srcLoc);
323
301
return spvContext.getUIntType (64 );
324
302
case BuiltinType::LitFloat: {
325
- // emitError("found literal float type when lowering types", srcLoc);
303
+ // emitError("found literal float type when lowering types", srcLoc);
326
304
return spvContext.getFloatType (64 );
327
305
328
306
default :
0 commit comments