[MLIR][Python] Add SCFIfOp Python binding

Current generated Python binding for the SCF dialect does not allow
users to call IfOp to create if-else branches on their own.
This PR sets up the default binding generation for scf.if operation
to address this problem.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D121076
This commit is contained in:
chhzh123 2022-03-13 05:24:00 +00:00 committed by Mehdi Amini
parent fd4d584d6b
commit 036088fd6e
2 changed files with 96 additions and 0 deletions

View File

@ -64,3 +64,44 @@ class ForOp:
To obtain the loop-carried operands, use `iter_args`.
"""
return self.body.arguments[1:]
class IfOp:
"""Specialization for the SCF if op class."""
def __init__(self,
cond,
results_=[],
*,
hasElse=False,
loc=None,
ip=None):
"""Creates an SCF `if` operation.
- `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
- `hasElse` determines whether the if operation has the else branch.
"""
operands = []
operands.append(cond)
results = []
results.extend(results_)
super().__init__(
self.build_generic(
regions=2,
results=results,
operands=operands,
loc=loc,
ip=ip))
self.regions[0].blocks.append(*[])
if hasElse:
self.regions[1].blocks.append(*[])
@property
def then_block(self):
"""Returns the then block of the if operation."""
return self.regions[0].blocks[0]
@property
def else_block(self):
"""Returns the else block of the if operation."""
return self.regions[1].blocks[0]

View File

@ -82,3 +82,58 @@ def testOpsAsArguments():
# CHECK: iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = %[[ARGS]]#1)
# CHECK: scf.yield %{{.*}}, %{{.*}}
# CHECK: return
@constructAndPrintInModule
def testIfWithoutElse():
bool = IntegerType.get_signless(1)
i32 = IntegerType.get_signless(32)
@builtin.FuncOp.from_py_func(bool)
def simple_if(cond):
if_op = scf.IfOp(cond)
with InsertionPoint(if_op.then_block):
one = arith.ConstantOp(i32, 1)
add = arith.AddIOp(one, one)
scf.YieldOp([])
return
# CHECK: func @simple_if(%[[ARG0:.*]]: i1)
# CHECK: scf.if %[[ARG0:.*]]
# CHECK: %[[ONE:.*]] = arith.constant 1
# CHECK: %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]
# CHECK: return
@constructAndPrintInModule
def testIfWithElse():
bool = IntegerType.get_signless(1)
i32 = IntegerType.get_signless(32)
@builtin.FuncOp.from_py_func(bool)
def simple_if_else(cond):
if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
with InsertionPoint(if_op.then_block):
x_true = arith.ConstantOp(i32, 0)
y_true = arith.ConstantOp(i32, 1)
scf.YieldOp([x_true, y_true])
with InsertionPoint(if_op.else_block):
x_false = arith.ConstantOp(i32, 2)
y_false = arith.ConstantOp(i32, 3)
scf.YieldOp([x_false, y_false])
add = arith.AddIOp(if_op.results[0], if_op.results[1])
return
# CHECK: func @simple_if_else(%[[ARG0:.*]]: i1)
# CHECK: %[[RET:.*]]:2 = scf.if %[[ARG0:.*]]
# CHECK: %[[ZERO:.*]] = arith.constant 0
# CHECK: %[[ONE:.*]] = arith.constant 1
# CHECK: scf.yield %[[ZERO]], %[[ONE]]
# CHECK: } else {
# CHECK: %[[TWO:.*]] = arith.constant 2
# CHECK: %[[THREE:.*]] = arith.constant 3
# CHECK: scf.yield %[[TWO]], %[[THREE]]
# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
# CHECK: return