[MLIR][Python] Add SCFIfOp Python binding
Current generated Python binding for the SCF dialect does not allow users to call IfOp to create if-else branches on their own. This PR sets up the default binding generation for scf.if operation to address this problem. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D121076
This commit is contained in:
parent
fd4d584d6b
commit
036088fd6e
|
|
@ -64,3 +64,44 @@ class ForOp:
|
|||
To obtain the loop-carried operands, use `iter_args`.
|
||||
"""
|
||||
return self.body.arguments[1:]
|
||||
|
||||
|
||||
class IfOp:
|
||||
"""Specialization for the SCF if op class."""
|
||||
|
||||
def __init__(self,
|
||||
cond,
|
||||
results_=[],
|
||||
*,
|
||||
hasElse=False,
|
||||
loc=None,
|
||||
ip=None):
|
||||
"""Creates an SCF `if` operation.
|
||||
|
||||
- `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
|
||||
- `hasElse` determines whether the if operation has the else branch.
|
||||
"""
|
||||
operands = []
|
||||
operands.append(cond)
|
||||
results = []
|
||||
results.extend(results_)
|
||||
super().__init__(
|
||||
self.build_generic(
|
||||
regions=2,
|
||||
results=results,
|
||||
operands=operands,
|
||||
loc=loc,
|
||||
ip=ip))
|
||||
self.regions[0].blocks.append(*[])
|
||||
if hasElse:
|
||||
self.regions[1].blocks.append(*[])
|
||||
|
||||
@property
|
||||
def then_block(self):
|
||||
"""Returns the then block of the if operation."""
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
@property
|
||||
def else_block(self):
|
||||
"""Returns the else block of the if operation."""
|
||||
return self.regions[1].blocks[0]
|
||||
|
|
|
|||
|
|
@ -82,3 +82,58 @@ def testOpsAsArguments():
|
|||
# CHECK: iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = %[[ARGS]]#1)
|
||||
# CHECK: scf.yield %{{.*}}, %{{.*}}
|
||||
# CHECK: return
|
||||
|
||||
|
||||
@constructAndPrintInModule
|
||||
def testIfWithoutElse():
|
||||
bool = IntegerType.get_signless(1)
|
||||
i32 = IntegerType.get_signless(32)
|
||||
|
||||
@builtin.FuncOp.from_py_func(bool)
|
||||
def simple_if(cond):
|
||||
if_op = scf.IfOp(cond)
|
||||
with InsertionPoint(if_op.then_block):
|
||||
one = arith.ConstantOp(i32, 1)
|
||||
add = arith.AddIOp(one, one)
|
||||
scf.YieldOp([])
|
||||
return
|
||||
|
||||
|
||||
# CHECK: func @simple_if(%[[ARG0:.*]]: i1)
|
||||
# CHECK: scf.if %[[ARG0:.*]]
|
||||
# CHECK: %[[ONE:.*]] = arith.constant 1
|
||||
# CHECK: %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]
|
||||
# CHECK: return
|
||||
|
||||
|
||||
@constructAndPrintInModule
|
||||
def testIfWithElse():
|
||||
bool = IntegerType.get_signless(1)
|
||||
i32 = IntegerType.get_signless(32)
|
||||
|
||||
@builtin.FuncOp.from_py_func(bool)
|
||||
def simple_if_else(cond):
|
||||
if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
|
||||
with InsertionPoint(if_op.then_block):
|
||||
x_true = arith.ConstantOp(i32, 0)
|
||||
y_true = arith.ConstantOp(i32, 1)
|
||||
scf.YieldOp([x_true, y_true])
|
||||
with InsertionPoint(if_op.else_block):
|
||||
x_false = arith.ConstantOp(i32, 2)
|
||||
y_false = arith.ConstantOp(i32, 3)
|
||||
scf.YieldOp([x_false, y_false])
|
||||
add = arith.AddIOp(if_op.results[0], if_op.results[1])
|
||||
return
|
||||
|
||||
|
||||
# CHECK: func @simple_if_else(%[[ARG0:.*]]: i1)
|
||||
# CHECK: %[[RET:.*]]:2 = scf.if %[[ARG0:.*]]
|
||||
# CHECK: %[[ZERO:.*]] = arith.constant 0
|
||||
# CHECK: %[[ONE:.*]] = arith.constant 1
|
||||
# CHECK: scf.yield %[[ZERO]], %[[ONE]]
|
||||
# CHECK: } else {
|
||||
# CHECK: %[[TWO:.*]] = arith.constant 2
|
||||
# CHECK: %[[THREE:.*]] = arith.constant 3
|
||||
# CHECK: scf.yield %[[TWO]], %[[THREE]]
|
||||
# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
|
||||
# CHECK: return
|
||||
|
|
|
|||
Loading…
Reference in New Issue