forked from OSchip/llvm-project
[MLIR] Rework generate-test-checks.py to attach CHECK lines to the source (test) file.
Summary:
This patch adds --source flag to indicate the source file. Then it tries to find insert
points in the source file and insert corresponding checks at those places.
Example output from Tensorflow XLA:
// -----
// CHECK-LABEL: func @main.3(
// CHECK-SAME: %[[VAL_0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 : index},
// CHECK-SAME: %[[VAL_1:.*]]: memref<16xi8> {xla_lhlo.alloc = 0 : index, xla_lhlo.liveout = true}) {
// CHECK: %[[VAL_2:.*]] = constant 0 : index
// CHECK: %[[VAL_3:.*]] = constant 0 : index
// CHECK: %[[VAL_4:.*]] = std.view %[[VAL_1]]{{\[}}%[[VAL_3]]][] : memref<16xi8> to memref<2x2xf32>
// CHECK: "xla_lhlo.tanh"(%[[VAL_0]], %[[VAL_4]]) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK: return
// CHECK: }
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
%res = "xla_hlo.tanh"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
Differential Revision: https://reviews.llvm.org/D81903
This commit is contained in:
parent
3f0c9c1634
commit
25b3806788
|
|
@ -56,6 +56,12 @@ class SSAVariableNamer:
|
||||||
def pop_name_scope(self):
|
def pop_name_scope(self):
|
||||||
self.scopes.pop()
|
self.scopes.pop()
|
||||||
|
|
||||||
|
def num_scopes(self):
|
||||||
|
return len(self.scopes)
|
||||||
|
|
||||||
|
def clear_counter(self):
|
||||||
|
self.name_counter = 0
|
||||||
|
|
||||||
|
|
||||||
# Process a line of input that has been split at each SSA identifier '%'.
|
# Process a line of input that has been split at each SSA identifier '%'.
|
||||||
def process_line(line_chunks, variable_namer):
|
def process_line(line_chunks, variable_namer):
|
||||||
|
|
@ -87,6 +93,22 @@ def process_line(line_chunks, variable_namer):
|
||||||
return output_line + '\n'
|
return output_line + '\n'
|
||||||
|
|
||||||
|
|
||||||
|
def process_source_lines(source_lines, note, args):
|
||||||
|
source_split_re = re.compile(args.source_delim_regex)
|
||||||
|
|
||||||
|
source_segments = [[]]
|
||||||
|
for line in source_lines:
|
||||||
|
if line == note:
|
||||||
|
continue
|
||||||
|
if line.find(args.check_prefix) != -1:
|
||||||
|
continue
|
||||||
|
if source_split_re.search(line):
|
||||||
|
source_segments.append([])
|
||||||
|
|
||||||
|
source_segments[-1].append(line + '\n')
|
||||||
|
return source_segments
|
||||||
|
|
||||||
|
|
||||||
# Pre-process a line of input to remove any character sequences that will be
|
# Pre-process a line of input to remove any character sequences that will be
|
||||||
# problematic with FileCheck.
|
# problematic with FileCheck.
|
||||||
def preprocess_line(line):
|
def preprocess_line(line):
|
||||||
|
|
@ -112,25 +134,51 @@ def main():
|
||||||
'--output',
|
'--output',
|
||||||
nargs='?',
|
nargs='?',
|
||||||
type=argparse.FileType('w'),
|
type=argparse.FileType('w'),
|
||||||
default=sys.stdout)
|
default=None)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'input',
|
'input',
|
||||||
nargs='?',
|
nargs='?',
|
||||||
type=argparse.FileType('r'),
|
type=argparse.FileType('r'),
|
||||||
default=sys.stdin)
|
default=sys.stdin)
|
||||||
|
parser.add_argument(
|
||||||
|
'--source', type=str,
|
||||||
|
help='Print each CHECK chunk before each delimeter line in the source'
|
||||||
|
'file, respectively. The delimeter lines are identified by '
|
||||||
|
'--source_delim_regex.')
|
||||||
|
parser.add_argument('--source_delim_regex', type=str, default='func @')
|
||||||
|
parser.add_argument(
|
||||||
|
'--starts_from_scope', type=int, default=1,
|
||||||
|
help='Omit the top specified level of content. For example, by default '
|
||||||
|
'it omits "module {"')
|
||||||
|
parser.add_argument('-i', '--inplace', action='store_true', default=False)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Open the given input file.
|
# Open the given input file.
|
||||||
input_lines = [l.rstrip() for l in args.input]
|
input_lines = [l.rstrip() for l in args.input]
|
||||||
args.input.close()
|
args.input.close()
|
||||||
|
|
||||||
output_lines = []
|
|
||||||
|
|
||||||
# Generate a note used for the generated check file.
|
# Generate a note used for the generated check file.
|
||||||
script_name = os.path.basename(__file__)
|
script_name = os.path.basename(__file__)
|
||||||
autogenerated_note = (ADVERT + 'utils/' + script_name)
|
autogenerated_note = (ADVERT + 'utils/' + script_name)
|
||||||
output_lines.append(autogenerated_note + '\n')
|
|
||||||
|
|
||||||
|
source_segments = None
|
||||||
|
if args.source:
|
||||||
|
source_segments = process_source_lines(
|
||||||
|
[l.rstrip() for l in open(args.source, 'r')],
|
||||||
|
autogenerated_note,
|
||||||
|
args
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.inplace:
|
||||||
|
assert args.output is None
|
||||||
|
output = open(args.source, 'w')
|
||||||
|
elif args.output is None:
|
||||||
|
output = sys.stdout
|
||||||
|
else:
|
||||||
|
output = args.output
|
||||||
|
|
||||||
|
output_segments = [[]]
|
||||||
# A map containing data used for naming SSA value names.
|
# A map containing data used for naming SSA value names.
|
||||||
variable_namer = SSAVariableNamer()
|
variable_namer = SSAVariableNamer()
|
||||||
for input_line in input_lines:
|
for input_line in input_lines:
|
||||||
|
|
@ -144,17 +192,25 @@ def main():
|
||||||
if is_block:
|
if is_block:
|
||||||
input_line = input_line.rsplit('//', 1)[0].rstrip()
|
input_line = input_line.rsplit('//', 1)[0].rstrip()
|
||||||
|
|
||||||
# Top-level operations are heuristically the operations at nesting level 1.
|
cur_level = variable_namer.num_scopes()
|
||||||
is_toplevel_op = (not is_block and input_line.startswith(' ') and
|
|
||||||
input_line[2] != ' ' and input_line[2] != '}')
|
|
||||||
|
|
||||||
# If the line starts with a '}', pop the last name scope.
|
# If the line starts with a '}', pop the last name scope.
|
||||||
if lstripped_input_line[0] == '}':
|
if lstripped_input_line[0] == '}':
|
||||||
variable_namer.pop_name_scope()
|
variable_namer.pop_name_scope()
|
||||||
|
cur_level = variable_namer.num_scopes()
|
||||||
|
|
||||||
# If the line ends with a '{', push a new name scope.
|
# If the line ends with a '{', push a new name scope.
|
||||||
if input_line[-1] == '{':
|
if input_line[-1] == '{':
|
||||||
variable_namer.push_name_scope()
|
variable_namer.push_name_scope()
|
||||||
|
if cur_level == args.starts_from_scope:
|
||||||
|
output_segments.append([])
|
||||||
|
|
||||||
|
# Omit lines at the near top level e.g. "module {".
|
||||||
|
if cur_level < args.starts_from_scope:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(output_segments[-1]) == 0:
|
||||||
|
variable_namer.clear_counter()
|
||||||
|
|
||||||
# Preprocess the input to remove any sequences that may be problematic with
|
# Preprocess the input to remove any sequences that may be problematic with
|
||||||
# FileCheck.
|
# FileCheck.
|
||||||
|
|
@ -164,7 +220,7 @@ def main():
|
||||||
ssa_split = input_line.split('%')
|
ssa_split = input_line.split('%')
|
||||||
|
|
||||||
# If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
|
# If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
|
||||||
if not is_toplevel_op or not ssa_split[0]:
|
if len(output_segments[-1]) != 0 or not ssa_split[0]:
|
||||||
output_line = '// ' + args.check_prefix + ': '
|
output_line = '// ' + args.check_prefix + ': '
|
||||||
# Pad to align with the 'LABEL' statements.
|
# Pad to align with the 'LABEL' statements.
|
||||||
output_line += (' ' * len('-LABEL'))
|
output_line += (' ' * len('-LABEL'))
|
||||||
|
|
@ -176,32 +232,40 @@ def main():
|
||||||
output_line += process_line(ssa_split[1:], variable_namer)
|
output_line += process_line(ssa_split[1:], variable_namer)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Append a newline to the output to separate the logical blocks.
|
|
||||||
output_lines.append('\n')
|
|
||||||
output_line = '// ' + args.check_prefix + '-LABEL: '
|
|
||||||
|
|
||||||
# Output the first line chunk that does not contain an SSA name for the
|
# Output the first line chunk that does not contain an SSA name for the
|
||||||
# label.
|
# label.
|
||||||
output_line += ssa_split[0] + '\n'
|
output_line = '// ' + args.check_prefix + '-LABEL: ' + ssa_split[0] + '\n'
|
||||||
|
|
||||||
# Process the rest of the input line on a separate check line.
|
# Process the rest of the input line on separate check lines.
|
||||||
if len(ssa_split) > 1:
|
for argument in ssa_split[1:]:
|
||||||
output_line += '// ' + args.check_prefix + '-SAME: '
|
output_line += '// ' + args.check_prefix + '-SAME: '
|
||||||
|
|
||||||
# Pad to align with the original position in the line.
|
# Pad to align with the original position in the line.
|
||||||
output_line += ' ' * len(ssa_split[0])
|
output_line += ' ' * len(ssa_split[0])
|
||||||
|
|
||||||
# Process the rest of the line.
|
# Process the rest of the line.
|
||||||
output_line += process_line(ssa_split[1:], variable_namer)
|
output_line += process_line([argument], variable_namer)
|
||||||
|
|
||||||
# Append the output line.
|
# Append the output line.
|
||||||
output_lines.append(output_line)
|
output_segments[-1].append(output_line)
|
||||||
|
|
||||||
|
output.write(autogenerated_note + '\n')
|
||||||
|
|
||||||
# Write the output.
|
# Write the output.
|
||||||
for output_line in output_lines:
|
if source_segments:
|
||||||
args.output.write(output_line)
|
assert len(output_segments) == len(source_segments)
|
||||||
args.output.write('\n')
|
for check_segment, source_segment in zip(output_segments, source_segments):
|
||||||
args.output.close()
|
for line in check_segment:
|
||||||
|
output.write(line)
|
||||||
|
for line in source_segment:
|
||||||
|
output.write(line)
|
||||||
|
else:
|
||||||
|
for segment in output_segments:
|
||||||
|
output.write('\n')
|
||||||
|
for output_line in segment:
|
||||||
|
output.write(output_line)
|
||||||
|
output.write('\n')
|
||||||
|
output.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue