[spirv] Use owning module ref to avoid leaks and fix ASAN tests

Differential Revision: https://reviews.llvm.org/D83982
This commit is contained in:
Lei Zhang 2020-07-16 16:05:51 -04:00
parent cc1b9b680f
commit 2dd9e43579
2 changed files with 17 additions and 14 deletions

View File

@ -103,7 +103,7 @@ public:
LogicalResult deserialize(); LogicalResult deserialize();
/// Collects the final SPIR-V ModuleOp. /// Collects the final SPIR-V ModuleOp.
Optional<spirv::ModuleOp> collect(); spirv::OwningSPIRVModuleRef collect();
private: private:
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
@ -111,7 +111,7 @@ private:
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
/// Initializes the `module` ModuleOp in this deserializer instance. /// Initializes the `module` ModuleOp in this deserializer instance.
spirv::ModuleOp createModuleOp(); spirv::OwningSPIRVModuleRef createModuleOp();
/// Processes SPIR-V module header in `binary`. /// Processes SPIR-V module header in `binary`.
LogicalResult processHeader(); LogicalResult processHeader();
@ -425,7 +425,7 @@ private:
Location unknownLoc; Location unknownLoc;
/// The SPIR-V ModuleOp. /// The SPIR-V ModuleOp.
Optional<spirv::ModuleOp> module; spirv::OwningSPIRVModuleRef module;
/// The current function under construction. /// The current function under construction.
Optional<spirv::FuncOp> curFunction; Optional<spirv::FuncOp> curFunction;
@ -556,13 +556,15 @@ LogicalResult Deserializer::deserialize() {
return success(); return success();
} }
Optional<spirv::ModuleOp> Deserializer::collect() { return module; } spirv::OwningSPIRVModuleRef Deserializer::collect() {
return std::move(module);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Module structure // Module structure
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
spirv::ModuleOp Deserializer::createModuleOp() { spirv::OwningSPIRVModuleRef Deserializer::createModuleOp() {
OpBuilder builder(context); OpBuilder builder(context);
OperationState state(unknownLoc, spirv::ModuleOp::getOperationName()); OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
spirv::ModuleOp::build(builder, state); spirv::ModuleOp::build(builder, state);
@ -1912,10 +1914,10 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
// Go through all ops and remap the operands. // Go through all ops and remap the operands.
auto remapOperands = [&](Operation *op) { auto remapOperands = [&](Operation *op) {
for (auto &operand : op->getOpOperands()) for (auto &operand : op->getOpOperands())
if (auto mappedOp = mapper.lookupOrNull(operand.get())) if (Value mappedOp = mapper.lookupOrNull(operand.get()))
operand.set(mappedOp); operand.set(mappedOp);
for (auto &succOp : op->getBlockOperands()) for (auto &succOp : op->getBlockOperands())
if (auto mappedOp = mapper.lookupOrNull(succOp.get())) if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
succOp.set(mappedOp); succOp.set(mappedOp);
}; };
for (auto &block : body) { for (auto &block : body) {
@ -2354,7 +2356,7 @@ Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
return emitError(unknownLoc, return emitError(unknownLoc,
"missing Execution Model specification in OpEntryPoint"); "missing Execution Model specification in OpEntryPoint");
} }
auto exec_model = opBuilder.getI32IntegerAttr(words[wordIndex++]); auto execModel = opBuilder.getI32IntegerAttr(words[wordIndex++]);
if (wordIndex >= words.size()) { if (wordIndex >= words.size()) {
return emitError(unknownLoc, "missing <id> in OpEntryPoint"); return emitError(unknownLoc, "missing <id> in OpEntryPoint");
} }
@ -2382,7 +2384,7 @@ Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation())); interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
wordIndex++; wordIndex++;
} }
opBuilder.create<spirv::EntryPointOp>(unknownLoc, exec_model, opBuilder.create<spirv::EntryPointOp>(unknownLoc, execModel,
opBuilder.getSymbolRefAttr(fnName), opBuilder.getSymbolRefAttr(fnName),
opBuilder.getArrayAttr(interface)); opBuilder.getArrayAttr(interface));
return success(); return success();
@ -2594,5 +2596,5 @@ spirv::OwningSPIRVModuleRef spirv::deserialize(ArrayRef<uint32_t> binary,
if (failed(deserializer.deserialize())) if (failed(deserializer.deserialize()))
return nullptr; return nullptr;
return deserializer.collect().getValueOr(nullptr); return deserializer.collect();
} }

View File

@ -15,6 +15,7 @@
#include "mlir/Dialect/SPIRV/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVModule.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
@ -56,7 +57,7 @@ protected:
} }
Type getFloatStructType() { Type getFloatStructType() {
OpBuilder opBuilder(module.body()); OpBuilder opBuilder(module->body());
llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()}; llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0}; llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
auto structType = spirv::StructType::get(elementTypes, offsetInfo); auto structType = spirv::StructType::get(elementTypes, offsetInfo);
@ -64,7 +65,7 @@ protected:
} }
void addGlobalVar(Type type, llvm::StringRef name) { void addGlobalVar(Type type, llvm::StringRef name) {
OpBuilder opBuilder(module.body()); OpBuilder opBuilder(module->body());
auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform); auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
opBuilder.create<spirv::GlobalVariableOp>( opBuilder.create<spirv::GlobalVariableOp>(
UnknownLoc::get(&context), TypeAttr::get(ptrType), UnknownLoc::get(&context), TypeAttr::get(ptrType),
@ -98,7 +99,7 @@ protected:
protected: protected:
MLIRContext context; MLIRContext context;
spirv::ModuleOp module; spirv::OwningSPIRVModuleRef module;
SmallVector<uint32_t, 0> binary; SmallVector<uint32_t, 0> binary;
}; };
@ -109,7 +110,7 @@ protected:
TEST_F(SerializationTest, BlockDecorationTest) { TEST_F(SerializationTest, BlockDecorationTest) {
auto structType = getFloatStructType(); auto structType = getFloatStructType();
addGlobalVar(structType, "var0"); addGlobalVar(structType, "var0");
ASSERT_TRUE(succeeded(spirv::serialize(module, binary))); ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
auto hasBlockDecoration = [](spirv::Opcode opcode, auto hasBlockDecoration = [](spirv::Opcode opcode,
ArrayRef<uint32_t> operands) -> bool { ArrayRef<uint32_t> operands) -> bool {
if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2) if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2)