[spirv] Use owning module ref to avoid leaks and fix ASAN tests
Differential Revision: https://reviews.llvm.org/D83982
This commit is contained in:
parent
cc1b9b680f
commit
2dd9e43579
|
|
@ -103,7 +103,7 @@ public:
|
||||||
LogicalResult deserialize();
|
LogicalResult deserialize();
|
||||||
|
|
||||||
/// Collects the final SPIR-V ModuleOp.
|
/// Collects the final SPIR-V ModuleOp.
|
||||||
Optional<spirv::ModuleOp> collect();
|
spirv::OwningSPIRVModuleRef collect();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
|
|
@ -111,7 +111,7 @@ private:
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Initializes the `module` ModuleOp in this deserializer instance.
|
/// Initializes the `module` ModuleOp in this deserializer instance.
|
||||||
spirv::ModuleOp createModuleOp();
|
spirv::OwningSPIRVModuleRef createModuleOp();
|
||||||
|
|
||||||
/// Processes SPIR-V module header in `binary`.
|
/// Processes SPIR-V module header in `binary`.
|
||||||
LogicalResult processHeader();
|
LogicalResult processHeader();
|
||||||
|
|
@ -425,7 +425,7 @@ private:
|
||||||
Location unknownLoc;
|
Location unknownLoc;
|
||||||
|
|
||||||
/// The SPIR-V ModuleOp.
|
/// The SPIR-V ModuleOp.
|
||||||
Optional<spirv::ModuleOp> module;
|
spirv::OwningSPIRVModuleRef module;
|
||||||
|
|
||||||
/// The current function under construction.
|
/// The current function under construction.
|
||||||
Optional<spirv::FuncOp> curFunction;
|
Optional<spirv::FuncOp> curFunction;
|
||||||
|
|
@ -556,13 +556,15 @@ LogicalResult Deserializer::deserialize() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
Optional<spirv::ModuleOp> Deserializer::collect() { return module; }
|
spirv::OwningSPIRVModuleRef Deserializer::collect() {
|
||||||
|
return std::move(module);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Module structure
|
// Module structure
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
spirv::ModuleOp Deserializer::createModuleOp() {
|
spirv::OwningSPIRVModuleRef Deserializer::createModuleOp() {
|
||||||
OpBuilder builder(context);
|
OpBuilder builder(context);
|
||||||
OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
|
OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
|
||||||
spirv::ModuleOp::build(builder, state);
|
spirv::ModuleOp::build(builder, state);
|
||||||
|
|
@ -1912,10 +1914,10 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
|
||||||
// Go through all ops and remap the operands.
|
// Go through all ops and remap the operands.
|
||||||
auto remapOperands = [&](Operation *op) {
|
auto remapOperands = [&](Operation *op) {
|
||||||
for (auto &operand : op->getOpOperands())
|
for (auto &operand : op->getOpOperands())
|
||||||
if (auto mappedOp = mapper.lookupOrNull(operand.get()))
|
if (Value mappedOp = mapper.lookupOrNull(operand.get()))
|
||||||
operand.set(mappedOp);
|
operand.set(mappedOp);
|
||||||
for (auto &succOp : op->getBlockOperands())
|
for (auto &succOp : op->getBlockOperands())
|
||||||
if (auto mappedOp = mapper.lookupOrNull(succOp.get()))
|
if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
|
||||||
succOp.set(mappedOp);
|
succOp.set(mappedOp);
|
||||||
};
|
};
|
||||||
for (auto &block : body) {
|
for (auto &block : body) {
|
||||||
|
|
@ -2354,7 +2356,7 @@ Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
|
||||||
return emitError(unknownLoc,
|
return emitError(unknownLoc,
|
||||||
"missing Execution Model specification in OpEntryPoint");
|
"missing Execution Model specification in OpEntryPoint");
|
||||||
}
|
}
|
||||||
auto exec_model = opBuilder.getI32IntegerAttr(words[wordIndex++]);
|
auto execModel = opBuilder.getI32IntegerAttr(words[wordIndex++]);
|
||||||
if (wordIndex >= words.size()) {
|
if (wordIndex >= words.size()) {
|
||||||
return emitError(unknownLoc, "missing <id> in OpEntryPoint");
|
return emitError(unknownLoc, "missing <id> in OpEntryPoint");
|
||||||
}
|
}
|
||||||
|
|
@ -2382,7 +2384,7 @@ Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
|
||||||
interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
|
interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
|
||||||
wordIndex++;
|
wordIndex++;
|
||||||
}
|
}
|
||||||
opBuilder.create<spirv::EntryPointOp>(unknownLoc, exec_model,
|
opBuilder.create<spirv::EntryPointOp>(unknownLoc, execModel,
|
||||||
opBuilder.getSymbolRefAttr(fnName),
|
opBuilder.getSymbolRefAttr(fnName),
|
||||||
opBuilder.getArrayAttr(interface));
|
opBuilder.getArrayAttr(interface));
|
||||||
return success();
|
return success();
|
||||||
|
|
@ -2594,5 +2596,5 @@ spirv::OwningSPIRVModuleRef spirv::deserialize(ArrayRef<uint32_t> binary,
|
||||||
if (failed(deserializer.deserialize()))
|
if (failed(deserializer.deserialize()))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
return deserializer.collect().getValueOr(nullptr);
|
return deserializer.collect();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@
|
||||||
#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
|
#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
|
||||||
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
|
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
|
||||||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||||
|
#include "mlir/Dialect/SPIRV/SPIRVModule.h"
|
||||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||||
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
|
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
|
|
@ -56,7 +57,7 @@ protected:
|
||||||
}
|
}
|
||||||
|
|
||||||
Type getFloatStructType() {
|
Type getFloatStructType() {
|
||||||
OpBuilder opBuilder(module.body());
|
OpBuilder opBuilder(module->body());
|
||||||
llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
|
llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
|
||||||
llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
|
llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
|
||||||
auto structType = spirv::StructType::get(elementTypes, offsetInfo);
|
auto structType = spirv::StructType::get(elementTypes, offsetInfo);
|
||||||
|
|
@ -64,7 +65,7 @@ protected:
|
||||||
}
|
}
|
||||||
|
|
||||||
void addGlobalVar(Type type, llvm::StringRef name) {
|
void addGlobalVar(Type type, llvm::StringRef name) {
|
||||||
OpBuilder opBuilder(module.body());
|
OpBuilder opBuilder(module->body());
|
||||||
auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
|
auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
|
||||||
opBuilder.create<spirv::GlobalVariableOp>(
|
opBuilder.create<spirv::GlobalVariableOp>(
|
||||||
UnknownLoc::get(&context), TypeAttr::get(ptrType),
|
UnknownLoc::get(&context), TypeAttr::get(ptrType),
|
||||||
|
|
@ -98,7 +99,7 @@ protected:
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
spirv::ModuleOp module;
|
spirv::OwningSPIRVModuleRef module;
|
||||||
SmallVector<uint32_t, 0> binary;
|
SmallVector<uint32_t, 0> binary;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -109,7 +110,7 @@ protected:
|
||||||
TEST_F(SerializationTest, BlockDecorationTest) {
|
TEST_F(SerializationTest, BlockDecorationTest) {
|
||||||
auto structType = getFloatStructType();
|
auto structType = getFloatStructType();
|
||||||
addGlobalVar(structType, "var0");
|
addGlobalVar(structType, "var0");
|
||||||
ASSERT_TRUE(succeeded(spirv::serialize(module, binary)));
|
ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
|
||||||
auto hasBlockDecoration = [](spirv::Opcode opcode,
|
auto hasBlockDecoration = [](spirv::Opcode opcode,
|
||||||
ArrayRef<uint32_t> operands) -> bool {
|
ArrayRef<uint32_t> operands) -> bool {
|
||||||
if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2)
|
if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue