[fir] TargetRewrite: Rewrite fir.address_of(func)
Rewrite AddrOfOp if taking the address of a function. Differential Revision: https://reviews.llvm.org/D114925 Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
This commit is contained in:
		
							parent
							
								
									867cd948ac
								
							
						
					
					
						commit
						3fd250d258
					
				| 
						 | 
					@ -100,6 +100,10 @@ public:
 | 
				
			||||||
      } else if (auto dispatch = dyn_cast<DispatchOp>(op)) {
 | 
					      } else if (auto dispatch = dyn_cast<DispatchOp>(op)) {
 | 
				
			||||||
        if (!hasPortableSignature(dispatch.getFunctionType()))
 | 
					        if (!hasPortableSignature(dispatch.getFunctionType()))
 | 
				
			||||||
          convertCallOp(dispatch);
 | 
					          convertCallOp(dispatch);
 | 
				
			||||||
 | 
					      } else if (auto addr = dyn_cast<AddrOfOp>(op)) {
 | 
				
			||||||
 | 
					        if (addr.getType().isa<mlir::FunctionType>() &&
 | 
				
			||||||
 | 
					            !hasPortableSignature(addr.getType()))
 | 
				
			||||||
 | 
					          convertAddrOp(addr);
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    });
 | 
					    });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -319,6 +323,55 @@ public:
 | 
				
			||||||
        newInTys.push_back(std::get<mlir::Type>(tup));
 | 
					        newInTys.push_back(std::get<mlir::Type>(tup));
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Taking the address of a function. Modify the signature as needed.
 | 
				
			||||||
 | 
					  void convertAddrOp(AddrOfOp addrOp) {
 | 
				
			||||||
 | 
					    rewriter->setInsertionPoint(addrOp);
 | 
				
			||||||
 | 
					    auto addrTy = addrOp.getType().cast<mlir::FunctionType>();
 | 
				
			||||||
 | 
					    llvm::SmallVector<mlir::Type> newResTys;
 | 
				
			||||||
 | 
					    llvm::SmallVector<mlir::Type> newInTys;
 | 
				
			||||||
 | 
					    for (mlir::Type ty : addrTy.getResults()) {
 | 
				
			||||||
 | 
					      llvm::TypeSwitch<mlir::Type>(ty)
 | 
				
			||||||
 | 
					          .Case<fir::ComplexType>([&](fir::ComplexType ty) {
 | 
				
			||||||
 | 
					            lowerComplexSignatureRes(ty, newResTys, newInTys);
 | 
				
			||||||
 | 
					          })
 | 
				
			||||||
 | 
					          .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
 | 
				
			||||||
 | 
					            lowerComplexSignatureRes(ty, newResTys, newInTys);
 | 
				
			||||||
 | 
					          })
 | 
				
			||||||
 | 
					          .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    llvm::SmallVector<mlir::Type> trailingInTys;
 | 
				
			||||||
 | 
					    for (mlir::Type ty : addrTy.getInputs()) {
 | 
				
			||||||
 | 
					      llvm::TypeSwitch<mlir::Type>(ty)
 | 
				
			||||||
 | 
					          .Case<BoxCharType>([&](BoxCharType box) {
 | 
				
			||||||
 | 
					            if (noCharacterConversion) {
 | 
				
			||||||
 | 
					              newInTys.push_back(box);
 | 
				
			||||||
 | 
					            } else {
 | 
				
			||||||
 | 
					              for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) {
 | 
				
			||||||
 | 
					                auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
 | 
				
			||||||
 | 
					                auto argTy = std::get<mlir::Type>(tup);
 | 
				
			||||||
 | 
					                llvm::SmallVector<mlir::Type> &vec =
 | 
				
			||||||
 | 
					                    attr.isAppend() ? trailingInTys : newInTys;
 | 
				
			||||||
 | 
					                vec.push_back(argTy);
 | 
				
			||||||
 | 
					              }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					          })
 | 
				
			||||||
 | 
					          .Case<fir::ComplexType>([&](fir::ComplexType ty) {
 | 
				
			||||||
 | 
					            lowerComplexSignatureArg(ty, newInTys);
 | 
				
			||||||
 | 
					          })
 | 
				
			||||||
 | 
					          .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
 | 
				
			||||||
 | 
					            lowerComplexSignatureArg(ty, newInTys);
 | 
				
			||||||
 | 
					          })
 | 
				
			||||||
 | 
					          .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    // append trailing input types
 | 
				
			||||||
 | 
					    newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
 | 
				
			||||||
 | 
					    // replace this op with a new one with the updated signature
 | 
				
			||||||
 | 
					    auto newTy = rewriter->getFunctionType(newInTys, newResTys);
 | 
				
			||||||
 | 
					    auto newOp =
 | 
				
			||||||
 | 
					        rewriter->create<AddrOfOp>(addrOp.getLoc(), newTy, addrOp.symbol());
 | 
				
			||||||
 | 
					    replaceOp(addrOp, newOp.getResult());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /// Convert the type signatures on all the functions present in the module.
 | 
					  /// Convert the type signatures on all the functions present in the module.
 | 
				
			||||||
  /// As the type signature is being changed, this must also update the
 | 
					  /// As the type signature is being changed, this must also update the
 | 
				
			||||||
  /// function itself to use any new arguments, etc.
 | 
					  /// function itself to use any new arguments, etc.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -93,3 +93,13 @@ fir.global @name constant : !fir.char<1,9> {
 | 
				
			||||||
  //constant 1
 | 
					  //constant 1
 | 
				
			||||||
  fir.has_value %str : !fir.char<1,9>
 | 
					  fir.has_value %str : !fir.char<1,9>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Test that we rewrite the fir.address_of operator
 | 
				
			||||||
 | 
					// INT32-LABEL: @addrof
 | 
				
			||||||
 | 
					// INT64-LABEL: @addrof
 | 
				
			||||||
 | 
					func @addrof() {
 | 
				
			||||||
 | 
					  // INT32: {{.*}} = fir.address_of(@boxcharcallee) : (!fir.ref<!fir.char<1,?>>, i32) -> ()
 | 
				
			||||||
 | 
					  // INT64: {{.*}} = fir.address_of(@boxcharcallee) : (!fir.ref<!fir.char<1,?>>, i64) -> ()
 | 
				
			||||||
 | 
					  %f = fir.address_of(@boxcharcallee) : (!fir.boxchar<1>) -> ()
 | 
				
			||||||
 | 
					  return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -452,3 +452,23 @@ func private @mlircomplexf32(%z1: complex<f32>, %z2: complex<f32>) -> complex<f3
 | 
				
			||||||
  // PPC: return [[RES]] : tuple<f32, f32>
 | 
					  // PPC: return [[RES]] : tuple<f32, f32>
 | 
				
			||||||
  return %0 : complex<f32>
 | 
					  return %0 : complex<f32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Test that we rewrite the fir.address_of operator.
 | 
				
			||||||
 | 
					// I32-LABEL: func @addrof()
 | 
				
			||||||
 | 
					// X64-LABEL: func @addrof()
 | 
				
			||||||
 | 
					// AARCH64-LABEL: func @addrof()
 | 
				
			||||||
 | 
					// PPC-LABEL: func @addrof()
 | 
				
			||||||
 | 
					func @addrof() {
 | 
				
			||||||
 | 
					  // I32: {{%.*}} = fir.address_of(@returncomplex4) : () -> i64
 | 
				
			||||||
 | 
					  // X64: {{%.*}} = fir.address_of(@returncomplex4) : () -> !fir.vector<2:!fir.real<4>>
 | 
				
			||||||
 | 
					  // AARCH64: {{%.*}} = fir.address_of(@returncomplex4) : () -> tuple<!fir.real<4>, !fir.real<4>>
 | 
				
			||||||
 | 
					  // PPC: {{%.*}} = fir.address_of(@returncomplex4) : () -> tuple<!fir.real<4>, !fir.real<4>>
 | 
				
			||||||
 | 
					  %r = fir.address_of(@returncomplex4) : () -> !fir.complex<4>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // I32: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.ref<tuple<!fir.real<4>, !fir.real<4>>>) -> ()
 | 
				
			||||||
 | 
					  // X64: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.vector<2:!fir.real<4>>) -> ()
 | 
				
			||||||
 | 
					  // AARCH64: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.array<2x!fir.real<4>>) -> ()
 | 
				
			||||||
 | 
					  // PPC: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.real<4>, !fir.real<4>) -> ()
 | 
				
			||||||
 | 
					  %p = fir.address_of(@paramcomplex4) : (!fir.complex<4>) -> ()
 | 
				
			||||||
 | 
					  return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue