diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2022-08-05 13:52:18 -0700 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2022-08-05 14:05:03 -0700 |
commit | 5f1c7e2cc5a3c07cbc2412e851a7283c1841f520 (patch) | |
tree | dcd912a37678075b1ddb3b1e6daad0e3d9a230a2 | |
parent | 3fa291fa925dad4bae215fbcbc64db5ce66f0d9f (diff) |
[mlir] Use SymbolTableCollection to lookup referenced symbol in AddressOfOp
Depends On D131285
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D131291
5 files changed, 18 insertions, 9 deletions
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 7dcd48f0c5e8..996175139575 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1031,10 +1031,10 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof", let extraClassDeclaration = [{ /// Return the llvm.mlir.global operation that defined the value referenced /// here. - GlobalOp getGlobal(); + GlobalOp getGlobal(SymbolTableCollection &symbolTable); /// Return the llvm.func operation that is referenced here. - LLVMFuncOp getFunction(); + LLVMFuncOp getFunction(SymbolTableCollection &symbolTable); }]; let assemblyFormat = "$global_name attr-dict `:` type($res)"; diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index f2b506692503..3dd0d769007c 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -16,6 +16,7 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/Value.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" #include "mlir/Target/LLVMIR/TypeToLLVM.h" @@ -264,6 +265,8 @@ public: ModuleTranslation &moduleTranslation; }; + SymbolTableCollection& symbolTable() { return symbolTableCollection; } + private: ModuleTranslation(Operation *module, std::unique_ptr<llvm::Module> llvmModule); @@ -333,6 +336,9 @@ private: /// Stack of user-specified state elements, useful when translating operations /// with regions. SmallVector<std::unique_ptr<StackFrame>> stack; + + /// A cache for the symbol tables constructed during symbols lookup. + SymbolTableCollection symbolTableCollection; }; namespace detail { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 4cb6a5658c51..25a64900ad66 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1737,14 +1737,14 @@ static Operation *parentLLVMModule(Operation *op) { return module; } -GlobalOp AddressOfOp::getGlobal() { +GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) { return dyn_cast_or_null<GlobalOp>( - SymbolTable::lookupSymbolIn(parentLLVMModule(*this), getGlobalName())); + symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr())); } -LLVMFuncOp AddressOfOp::getFunction() { +LLVMFuncOp AddressOfOp::getFunction(SymbolTableCollection &symbolTable) { return dyn_cast_or_null<LLVMFuncOp>( - SymbolTable::lookupSymbolIn(parentLLVMModule(*this), getGlobalName())); + symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr())); } LogicalResult diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index bdc891b93baa..67535ce6a0c7 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -466,8 +466,10 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, // operation and store it in the MLIR-to-LLVM value mapping. This does not // emit any LLVM instruction. if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) { - LLVM::GlobalOp global = addressOfOp.getGlobal(); - LLVM::LLVMFuncOp function = addressOfOp.getFunction(); + LLVM::GlobalOp global = + addressOfOp.getGlobal(moduleTranslation.symbolTable()); + LLVM::LLVMFuncOp function = + addressOfOp.getFunction(moduleTranslation.symbolTable()); // The verifier should not have allowed this. assert((global || function) && diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 85ec47aae400..ba231d643f4d 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -1285,7 +1285,8 @@ convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, return opInst.emitError("Addressing symbol not found"); LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp); - LLVM::GlobalOp global = addressOfOp.getGlobal(); + LLVM::GlobalOp global = + addressOfOp.getGlobal(moduleTranslation.symbolTable()); llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global); llvm::Value *data = builder.CreateBitCast(globalValue, builder.getInt8PtrTy()); |