add fallback_func for static graph compiler

This commit is contained in:
Dun Liang 2024-01-03 03:57:14 +08:00
parent 4b9d777570
commit 2744992946
1 changed files with 4 additions and 2 deletions

View File

@ -84,7 +84,7 @@ class CachedGraph:
exec_called = jt.flags.exec_called
self.outputs = func(*args, **kw)
import gc; gc.collect()
assert exec_called == jt.flags.exec_called
assert exec_called == jt.flags.exec_called, (exec_called, jt.flags.exec_called)
self.outputs_parsed = dfs(self.outputs)
self.outputs_var = [ v for _, v in self.outputs_parsed ]
self.inputs_parsed = dfs(self.inputs)
@ -103,9 +103,11 @@ class CachedGraph:
# 3. var path changed
# graph key:
# (args, kw), [ (var_path, shape dim, dtype), var ]
def build(func, debug=False):
def build(func, debug=False, fallback_func=None):
cache = {}
def func_wrapper(*args, **kw):
if fallback_func and fallback_func(*args, **kw):
return func(*args, **kw)
inputs = (args, kw)
config_key = str(dfs_config(inputs))
inputs_parsed = dfs(inputs)