mirror of https://github.com/Jittor/Jittor
add fallback_func for static graph compiler
This commit is contained in:
parent
4b9d777570
commit
2744992946
|
@ -84,7 +84,7 @@ class CachedGraph:
|
|||
exec_called = jt.flags.exec_called
|
||||
self.outputs = func(*args, **kw)
|
||||
import gc; gc.collect()
|
||||
assert exec_called == jt.flags.exec_called
|
||||
assert exec_called == jt.flags.exec_called, (exec_called, jt.flags.exec_called)
|
||||
self.outputs_parsed = dfs(self.outputs)
|
||||
self.outputs_var = [ v for _, v in self.outputs_parsed ]
|
||||
self.inputs_parsed = dfs(self.inputs)
|
||||
|
@ -103,9 +103,11 @@ class CachedGraph:
|
|||
# 3. var path changed
|
||||
# graph key:
|
||||
# (args, kw), [ (var_path, shape dim, dtype), var ]
|
||||
def build(func, debug=False):
|
||||
def build(func, debug=False, fallback_func=None):
|
||||
cache = {}
|
||||
def func_wrapper(*args, **kw):
|
||||
if fallback_func and fallback_func(*args, **kw):
|
||||
return func(*args, **kw)
|
||||
inputs = (args, kw)
|
||||
config_key = str(dfs_config(inputs))
|
||||
inputs_parsed = dfs(inputs)
|
||||
|
|
Loading…
Reference in New Issue