mirror of https://github.com/Jittor/Jittor
677 lines
19 KiB
Python
677 lines
19 KiB
Python
# ***************************************************************
|
|
# Copyright (c) 2020 Jittor. Authors:
|
|
# Dun Liang <randonlang@gmail.com>.
|
|
# Meng-Hao Guo <guomenghao1997@gmail.com>
|
|
#
|
|
# All Rights Reserved.
|
|
# This file is subject to the terms and conditions defined in
|
|
# file 'LICENSE.txt', which is part of this source code package.
|
|
# ***************************************************************
|
|
from . import lock
|
|
with lock.lock_scope():
|
|
from . import compiler
|
|
from .compiler import LOG, has_cuda
|
|
from .compiler import compile_custom_ops, compile_custom_op
|
|
import jittor_core as core
|
|
from jittor_core import *
|
|
from jittor_core.ops import *
|
|
from . import compile_extern
|
|
from .compile_extern import mkl_ops
|
|
|
|
import contextlib
|
|
import numpy as np
|
|
from collections import OrderedDict
|
|
import types
|
|
import pickle
|
|
import sys
|
|
import traceback
|
|
|
|
def dfs(scope, vars):
    """Depth-first collect every Var stored under *scope* into the list *vars*."""
    for child in scope.children.values():
        if type(child) == Scope:
            dfs(child, vars)
        else:
            vars.append(child)
|
|
|
|
def dfs_records(scope, records):
    """Depth-first collect every recorded value under *scope* into *records*."""
    for child in scope.children.values():
        if type(child) == Scope:
            dfs_records(child, records)
    records.extend(scope.records.values())
|
|
|
|
class Scope:
    """A tree node mapping names to nested scopes and variables.

    ``children`` maps a name to either a nested Scope or a Var;
    ``records`` holds values saved via record_in_scope; ``index``
    assigns sequence numbers to non-unique names.
    """

    def __init__(self, parent=None, name=None):
        self.children = OrderedDict()
        self.index = {}
        self.records = OrderedDict()
        if name is None:
            # the root scope has an empty name and path
            self.name = self.full_name = ""
        else:
            self.name = name
            self.full_name = parent.full_name + name + "/"

    def get_scope(self, name, unique=True):
        """Return (creating if needed) the child scope *name*.

        When *unique* is False a sequence number is appended, so repeated
        calls yield fresh sibling scopes ``name_0``, ``name_1``, ...
        """
        if not unique:
            serial = self.index.get(name, 0)
            self.index[name] = serial + 1
            name = f'{name}_{serial}'
        sub_scope = self.children.get(name)
        if sub_scope is None:
            sub_scope = Scope(self, name)
            self.children[name] = sub_scope
        assert type(sub_scope) == Scope, f"Name {name} is a Var: {sub_scope}"
        return sub_scope

    def make_var(self, shape, dtype, init, name, unique):
        """Return the variable *name* in this scope, creating it on first use.

        *init* may be a Var, a callable ``init(shape, dtype)``, or raw data;
        re-fetching an existing var asserts shape/dtype consistency.
        """
        if not unique:
            serial = self.index.get(name, 0)
            self.index[name] = serial + 1
            name = f'{name}_{serial}'
        if name in self.children:
            var = self.children[name]
            assert type(var) == core.Var, f"Name {name} exist: {var}"
            assert (shape is None or var.shape == shape) and var.dtype == dtype, f"Shape or dtype not match {var} != {dtype}{shape}"
            return var
        full_name = self.full_name + name
        if type(init) == core.Var:
            var = init
        elif callable(init):
            var = init(shape, dtype)
            if type(var) != core.Var:
                var = array(var)
        else:
            assert init != None
            var = array(init)
        # keep the variable out of op fusion so it stays addressable
        var.stop_fuse()
        self.children[name] = var
        var.name(full_name)
        return var

    def clean_index(self):
        """Reset the per-name sequence counters."""
        self.index.clear()

    def clean(self):
        """Drop all children, records and counters."""
        self.children.clear()
        self.records.clear()
        self.index.clear()
|
|
|
|
# Scope that new vars are created in; var_scope pushes/pops this global.
current_scope = Scope()
# Handle to the top-level scope, kept so it can be reached while nested.
root_scope = current_scope
|
|
|
|
class _call_record_scope:
    """Base for scopes usable as context managers or as decorators.

    A decorated function runs inside the scope and has its return value
    recorded under the name "output" via record_in_scope.
    """

    def __enter__(self):
        pass

    def __exit__(self, *exc):
        pass

    def __call__(self, func):
        def inner(*args, **kw):
            with self:
                result = func(*args, **kw)
                # record while the scope is still entered, so the record
                # lands in the scope this decorator opened
                record_in_scope(result, "output")
            return result
        return inner
|
|
|
|
class _call_no_record_scope:
|
|
def __enter__(self): pass
|
|
def __exit__(self, *exc): pass
|
|
def __call__(self, func):
|
|
def inner(*args, **kw):
|
|
with self:
|
|
ret = func(*args, **kw)
|
|
return ret
|
|
return inner
|
|
|
|
class flag_scope(_call_no_record_scope):
    """Temporarily override jittor flags, restoring old values on exit.

    Example: ``with jt.flag_scope(log_v=1): ...``
    """

    def __init__(self, **jt_flags):
        self.jt_flags = jt_flags

    def __enter__(self):
        self.flags_bk = {}
        try:
            # back up each flag just before overwriting it, so a failure
            # part-way through can still restore everything set so far
            for key, value in self.jt_flags.items():
                self.flags_bk[key] = getattr(flags, key)
                setattr(flags, key, value)
        except:
            self.__exit__()
            raise

    def __exit__(self, *exc):
        for key, value in self.flags_bk.items():
            setattr(flags, key, value)
|
|
|
|
class var_scope(_call_record_scope):
    """Enter a named child scope (optionally overriding flags) so that vars
    created inside get hierarchical names.

    example:
        with jt.var_scope("block"):
            w = jt.make_var(...)
    """
    def __init__(self, name="scope", unique=False, **jt_flags):
        # flag overrides are delegated to a nested flag_scope
        self.fs = flag_scope(**jt_flags)
        self.name = name
        self.unique = unique

    def __enter__(self):
        global current_scope
        self.prev = current_scope
        try:
            current_scope = current_scope.get_scope(self.name, self.unique)
            # restart sequence numbering each time the scope is re-entered
            current_scope.clean_index()
            self.fs.__enter__()
        except:
            # roll the global scope back if entering failed part-way
            current_scope = self.prev
            del self.prev
            raise

    def __exit__(self, *exc):
        self.fs.__exit__(*exc)
        global current_scope
        current_scope = self.prev
        del self.prev
|
|
|
|
# Guard so that only one log_capture_scope can be active at a time.
single_log_capture = None
|
|
|
|
class log_capture_scope(_call_no_record_scope):
    """log capture scope

    example:
        with jt.log_capture_scope(log_v=0) as logs:
            LOG.v("...")
        print(logs)
    """
    def __init__(self, **jt_flags):
        self.fs = flag_scope(**jt_flags)

    def __enter__(self):
        global single_log_capture
        # only one capture may run at a time: LOG's capture buffer is global
        assert not single_log_capture
        single_log_capture = True
        self.logs = []
        LOG.log_capture_start()
        try:
            self.fs.__enter__()
            return self.logs
        except:
            # undo the capture state if applying flags failed
            LOG.log_capture_stop()
            single_log_capture = None
            raise

    def __exit__(self, *exc):
        global single_log_capture
        # flags are restored before stopping capture, so logs emitted while
        # restoring are still captured
        self.fs.__exit__(*exc)
        LOG.log_capture_stop()
        self.logs.extend(LOG.log_capture_read())
        single_log_capture = None
|
|
|
|
|
|
class profile_scope(_call_no_record_scope):
    """ profile scope

    example:
        with jt.profile_scope() as report:
            ......
        print(report)
    """
    def __init__(self, warmup=0, rerun=0, **jt_flags):
        self.fs = flag_scope(**jt_flags)
        self.warmup = warmup
        self.rerun = rerun

    def __enter__(self):
        # nested profiling is not supported
        assert not flags.profiler_enable
        profiler.start(self.warmup, self.rerun)
        self.report = []
        try:
            self.fs.__enter__()
            return self.report
        except:
            # stop the profiler if applying flags failed
            profiler.stop()
            raise

    def __exit__(self, *exc):
        self.fs.__exit__(*exc)
        profiler.stop()
        # fill the list yielded by __enter__ with the profiler output
        self.report.extend(profiler.report())
|
|
|
|
def make_var(shape=None, dtype="float32", init=None, name='var', unique=False):
    """Create (or fetch) a named variable in the current scope.

    *init* may be a Var, raw data, or a callable ``init(shape, dtype)``;
    see Scope.make_var for the exact rules.
    """
    return current_scope.make_var(shape, dtype, init, name, unique)
|
|
|
|
def find_vars(path=None):
    """Return every Var under *path* ("a/b" or "a/b/"); all vars when None."""
    scope = current_scope
    if path is not None:
        assert isinstance(path, str)
        parts = path.split("/")
        # tolerate a trailing "/" on the path
        if parts[-1] == "":
            parts.pop()
        for part in parts:
            scope = scope.children[part]
    if not isinstance(scope, Scope):
        # the path named a single Var directly
        return [scope]
    found = []
    dfs(scope, found)
    return found
|
|
|
|
def find_var(path):
    """Return the single Var at *path*; asserts if the path names a scope."""
    scope = current_scope
    if path is not None:
        assert isinstance(path, str)
        for part in path.split("/"):
            scope = scope.children[part]
    assert not isinstance(scope, Scope)
    return scope
|
|
|
|
def find_records(path=None):
    """Return every recorded value under the scope at *path* (all when None)."""
    scope = current_scope
    if path is not None:
        assert isinstance(path, str)
        parts = path.split("/")
        # tolerate a trailing "/" on the path
        if parts[-1] == "":
            parts.pop()
        for part in parts:
            scope = scope.children[part]
    assert isinstance(scope, Scope)
    found = []
    dfs_records(scope, found)
    return found
|
|
|
|
def find_record(path):
    """Return the value recorded under the last component of *path*."""
    assert isinstance(path, str)
    scope = current_scope
    parts = path.split("/")
    # walk to the scope holding the record, then look it up by name
    for part in parts[:-1]:
        scope = scope.children[part]
    assert isinstance(scope, Scope)
    return scope.records[parts[-1]]
|
|
|
|
def find_scope(path):
    """Return the Scope object at *path*; the current scope when path is None."""
    scope = current_scope
    if path is not None:
        assert isinstance(path, str)
        parts = path.split("/")
        # tolerate a trailing "/" on the path
        if parts[-1] == "":
            parts.pop()
        for part in parts:
            scope = scope.children[part]
    assert isinstance(scope, Scope)
    return scope
|
|
|
|
def record_in_scope(self, name):
    """Store *self* under *name* in the current scope's records.

    A Var is also renamed to its full scope path. Returns *self* so the
    call can be chained.
    """
    current_scope.records[name] = self
    if isinstance(self, Var):
        self.name(current_scope.full_name + name)
    return self

# expose as a Var method
Var.record_in_scope = record_in_scope
|
|
|
|
def clean():
    """Drop every var/record in the current scope and force a GC pass."""
    import gc
    current_scope.clean()
    # make sure python does a full collection
    gc.collect()
|
|
|
|
# cast(x, dtype) is just the unary op with the dtype as its operator name
cast = unary
|
|
|
|
def array(data, dtype=None):
    """Convert *data* (Var, numpy array, or array-like) into a jittor Var.

    A Var input is cast (to its own dtype when *dtype* is None); numpy
    input is copied only when it is not C-contiguous.
    """
    if type(data) == core.Var:
        if dtype is None:
            return cast(data, data.dtype)
        return cast(data, dtype)
    # `is not None` (not `!=`) so dtype objects are never equality-compared
    if dtype is not None:
        return ops.array(np.array(data, dtype))
    if type(data) == np.ndarray:
        # ops.array requires a C-contiguous buffer
        if data.flags.c_contiguous:
            return ops.array(data)
        return ops.array(data.copy())
    return ops.array(np.array(data))
|
|
|
|
def grad(loss, targets):
    """Gradients of *loss* w.r.t. *targets*.

    A single Var target returns a single Var instead of a list.
    """
    if type(targets) == core.Var:
        grads = core.grad(loss, [targets])
        return grads[0]
    return core.grad(loss, targets)
|
|
|
|
def liveness_info():
    """Counters of currently held/lived vars and ops, for leak debugging."""
    return dict(
        hold_vars=core.number_of_hold_vars(),
        lived_vars=core.number_of_lived_vars(),
        lived_ops=core.number_of_lived_ops(),
    )
|
|
|
|
def ones(shape, dtype="float32"):
    """Return a Var of *shape* filled with ones."""
    one = unary(1, dtype)
    return one.broadcast(shape)
|
|
|
|
def zeros(shape, dtype="float32"):
    """Return a Var of *shape* filled with zeros."""
    zero = unary(0, dtype)
    return zero.broadcast(shape)
|
|
|
|
# Global jittor flag object (see flag_scope for scoped overrides).
flags = core.flags()
|
|
|
|
def detach(x):
    """Return a detached copy of *x* (clone -> stop_grad -> clone)."""
    no_grad = x.clone().stop_grad()
    return no_grad.clone()

Var.detach = detach
|
|
|
|
def view(x, *shape):
    """Reshape *x* torch-style.

    Accepts either separate dims (``x.view(2, 3)``) or a single
    tuple/list (``x.view((2, 3))`` / ``x.view([2, 3])``).
    """
    # also unwrap a list, which was previously passed through unmodified
    # and produced a bogus 1-tuple shape
    if isinstance(shape[0], (tuple, list)):
        shape = shape[0]
    return x.reshape(shape)
|
|
# expose view as a Var method
Var.view = view
|
|
|
|
def permute(x, *dim):
    """Reorder the axes of *x* torch-style.

    Accepts either separate dims (``x.permute(1, 0)``) or a single
    tuple/list (``x.permute((1, 0))`` / ``x.permute([1, 0])``).
    """
    # also unwrap a list, which was previously passed through unmodified
    if isinstance(dim[0], (tuple, list)):
        dim = dim[0]
    return transpose(x, dim)

# expose permute as a Var method
Var.permute = permute
|
|
|
|
def flatten(input, start_dim=0, end_dim=-1):
    """Collapse dims [start_dim, end_dim] of *input* into one dimension.

    Negative dims count from the end. ``start_dim == end_dim`` (which
    includes a 1-D input with the defaults) is the identity.
    """
    in_shape = input.shape
    ndim = len(in_shape)
    start_dim = ndim + start_dim if start_dim < 0 else start_dim
    end_dim = ndim + end_dim if end_dim < 0 else end_dim
    # >= (not >): flattening a single dim is a valid no-op, and the old
    # strict check rejected any 1-D input under the default arguments
    assert end_dim >= start_dim, "end_dim should not be smaller than start_dim for flatten function"
    out_shape = []
    for i in range(0, start_dim, 1):
        out_shape.append(in_shape[i])
    dims = 1
    for i in range(start_dim, end_dim + 1, 1):
        dims *= in_shape[i]
    out_shape.append(dims)
    for i in range(end_dim + 1, ndim, 1):
        out_shape.append(in_shape[i])
    return input.reshape(out_shape)
|
|
# expose flatten as a Var method
Var.flatten = flatten
|
|
|
|
def detach_inplace(x):
    """Detach *x* in place by swapping it with a grad-stopped clone."""
    return x.swap(x.stop_grad().clone())

# `start_grad` alias is used by Module.train to re-enable gradients
Var.start_grad = Var.detach_inplace = detach_inplace
|
|
|
|
def unsqueeze(x, dim):
    """Insert a size-1 dimension at position *dim*.

    Negative dims count from the end torch-style, so ``dim=-1`` appends
    a trailing axis.
    """
    shape = list(x.shape)
    if dim < 0:
        # e.g. dim=-1 on a rank-2 input inserts at position 2; the old
        # python-slice behavior inserted one position too early
        dim += len(shape) + 1
    assert 0 <= dim <= len(shape)
    return x.reshape(shape[:dim] + [1] + shape[dim:])
|
|
# expose unsqueeze as a Var method
Var.unsqueeze = unsqueeze
|
|
|
|
def squeeze(x, dim):
    """Remove the size-1 dimension at position *dim* (negative allowed).

    Negative dims count from the end torch-style.
    """
    shape = list(x.shape)
    if dim < 0:
        # normalize: the old slice arithmetic duplicated dims for dim=-1
        dim += len(shape)
    assert 0 <= dim < len(shape)
    assert shape[dim] == 1
    return x.reshape(shape[:dim] + shape[dim+1:])
|
|
# expose squeeze as a Var method
Var.squeeze = squeeze
|
|
|
|
def clamp(x, min_v, max_v):
    """Elementwise clamp of *x* into [min_v, max_v], built from masks."""
    # TODO: change to x.maximum(min_v).minimum(max_v)
    assert min_v <= max_v
    below = (x < min_v).int()
    above = (x > max_v).int()
    # each element picks exactly one of: itself, min_v, or max_v
    return x * (1 - below - above) + min_v * below + max_v * above

Var.clamp = clamp
|
|
|
|
def type_as(a, b):
    """Cast *a* to the dtype of *b* (torch-style helper)."""
    return a.unary(op=b.dtype)

# expose as a Var method
Var.type_as = type_as
|
|
|
|
def masked_fill(x, mask, value):
    """Where *mask* is 1 take *value*, where it is 0 keep *x*."""
    assert list(x.shape) == list(mask.shape)
    # TODO: assert mask = 0 or 1
    keep = 1 - mask
    return x * keep + mask * value
|
|
# expose masked_fill as a Var method
Var.masked_fill = masked_fill
|
|
|
|
|
|
def sqr(x):
    """Elementwise square of x."""
    return x*x

# expose as a Var method
Var.sqr = sqr
|
|
|
|
def attrs(var):
    """Snapshot of a Var's commonly-inspected attributes as a dict."""
    return dict(
        is_stop_fuse=var.is_stop_fuse(),
        is_stop_grad=var.is_stop_grad(),
        shape=var.shape,
        dtype=var.dtype,
    )
|
|
# expose attrs as a Var method
Var.attrs = attrs
|
|
|
|
def fetch(vars, func, *args, **kw):
    """Asynchronously fetch *vars*; calls func(*values, *args, **kw) when ready."""
    callback = lambda *results: func(*results, *args, **kw)
    core.fetch(vars, callback)
|
|
|
|
def fetch_var(var, func, *args, **kw):
    """Asynchronously fetch one Var; calls func(value, *args, **kw) when ready."""
    core.fetch([var], lambda a: func(a, *args, **kw))

Var.fetch = fetch_var
# keep the module namespace clean: only the Var method remains
del fetch_var
|
|
|
|
def import_vars(data):
    ''' Load variables into current scopes

    example:
        import_vars({"w":[1.0,2.0,3.0]})
        jt.get_var([3], "float64", name="w", gen_index=False)
    '''
    for k in data:
        value = data[k]
        if type(value) != core.Var:
            value = array(value).stop_fuse()
        parts = k.split("/")
        scope = current_scope
        # walk/create intermediate scopes; the last component is the var name
        for part in parts[:-1]:
            scope = scope.get_scope(part)
        vname = parts[-1]
        assert vname not in scope.children, f"Var {k} exists. Please load_vars at the beginning"
        value.name(k)
        scope.children[vname] = value
|
|
|
|
def export_vars():
    ''' Export all vars into a dictionary

    return: a dictionary, key is var name, value is numpy array
    '''
    pairs = ((v.name(), v.fetch_sync()) for v in find_vars())
    return dict(pairs)
|
|
|
|
def load(path):
    """Unpickle and return the model dict stored at *path*.

    Note: pickle is unsafe on untrusted input — only load files you trust.
    """
    # context manager so the file handle is closed even if loading fails
    with open(path, 'rb') as pkl_file:
        model_dict = pickle.load(pkl_file)
    return model_dict
|
|
|
|
class Module:
    """Base class for neural-network modules.

    Subclasses implement ``execute``; calling the module runs it. Child
    modules are discovered by scanning ``__dict__``, and parameters are
    every Var reachable through the module tree.
    """

    def __init__(self, *args, **kw):
        # base class keeps no state (the old body was a no-op
        # `__doc__ == 'doc'` comparison); subclasses set attributes freely
        pass

    def execute(self, *args, **kw):
        # subclasses override this with the forward computation
        pass

    def __call__(self, *args, **kw):
        return self.execute(*args, **kw)

    def __repr__(self):
        return self.__str__()

    def _get_name(self):
        return self.__class__.__name__

    def __doc__(self):
        pass

    def __name__(self):
        pass

    def dfs(self, parents, k, callback, callback_leave=None):
        """Walk the module tree depth-first.

        ``callback(parents, name, module, n_children)`` runs before the
        children (return False to prune); ``callback_leave`` runs after.
        """
        n_children = 0
        for v in self.__dict__.values():
            if isinstance(v, Module):
                n_children += 1
        ret = callback(parents, k, self, n_children)
        if ret == False: return
        # iterate with distinct names so `k` still holds this module's own
        # key when callback_leave runs (the old loop shadowed it)
        for ck, cv in self.__dict__.items():
            if not isinstance(cv, Module):
                continue
            parents.append(self)
            cv.dfs(parents, ck, callback, callback_leave)
            parents.pop()
        if callback_leave:
            callback_leave(parents, k, self, n_children)

    def __str__(self):
        """Render the module tree as an indented multi-line string."""
        ss = []
        def callback(parents, k, v, n):
            # indent key:class_name(extra_repr)
            k = f"{k}: " if k is not None else ""
            s = f"{' '*(len(parents)*4)}{k}{v.__class__.__name__}"
            if n:
                s += '('
            else:
                s += f"({v.extra_repr()})"
            ss.append(s)
        def callback_leave(parents, k, v, n):
            if n:
                ss.append(' '*(len(parents)*4)+')')
        self.dfs([], None, callback, callback_leave)
        return "\n".join(ss)

    def parameters(self):
        """Collect every Var in the tree, naming each by its dotted path."""
        ps = []
        stack = []
        def callback(parents, k, v, n):
            stack.append(str(k))
            for k2, p in v.__dict__.items():
                if isinstance(p, Var):
                    ps.append(p)
                    # name like "layer1.weight" (root key is skipped)
                    p.name(".".join(stack[1:]+[str(k2)]))
        def callback_leave(parents, k, v, n):
            stack.pop()
        self.dfs([], None, callback, callback_leave)
        return ps

    def modules(self):
        """Return this module and every descendant module, pre-order."""
        ms = []
        def callback(parents, k, v, n):
            if isinstance(v, Module):
                ms.append(v)
        self.dfs([], None, callback, None)
        return ms

    def children(self):
        """Return only the direct child modules."""
        cd = []
        def callback(parents, k, v, n):
            if len(parents) == 1 and isinstance(v, Module):
                cd.append(v)
                # prune: do not descend past direct children
                return False
        self.dfs([], None, callback, None)
        return cd

    def extra_repr(self):
        """Summarize constructor arguments: positional values first, then
        key=value for defaulted parameters."""
        ss = []
        # __defaults__ is None when __init__ has no defaulted parameters;
        # the old len(...) call crashed in that case
        n = len(self.__init__.__code__.co_varnames) - \
            len(self.__init__.__defaults__ or ())
        for i, k in enumerate(self.__init__.__code__.co_varnames[1:]):
            v = getattr(self, k) if hasattr(self, k) else None
            if isinstance(v, Var): v = v.peek()
            s = f"{k}={v}" if i >= n else str(v)
            ss.append(s)
        return ", ".join(ss)

    def load_parameters(self, params):
        """Assign values from *params* (a dict keyed by dotted paths) onto
        matching sub-attributes; keys that do not resolve are skipped."""
        for key in params.keys():
            v = self
            key_ = key.split('.')
            end = 0
            for k in key_:
                if isinstance(v, nn.Sequential):
                    # builtin int(): np.int was removed in modern numpy
                    if int(k) >= len(v.layers):
                        end = 1
                        break
                    else:
                        v = v[int(k)]
                else:
                    if hasattr(v, k):
                        v = getattr(v, k)
                    else:
                        end = 1
                        break
            if end == 1:
                # key did not resolve to an attribute; skip it silently
                continue
            value = params[key]
            if isinstance(value, (np.ndarray, list)):
                v.assign(array(value))
            elif isinstance(value, Var):
                v.assign(value)
            else:
                # assume a torch-tensor-like object with .cpu()/.detach()
                v.assign(array(value.cpu().detach().numpy()))

    def save(self, path):
        """Pickle {name: numpy data} for every parameter to *path*."""
        params = self.parameters()
        params_dict = {}
        for p in params:
            params_dict[p.name()] = p.data
        with open(path, 'wb') as f:
            pickle.dump(params_dict, f, pickle.HIGHEST_PROTOCOL)

    def eval(self):
        """Switch to inference mode: clear `is_train` flags and stop grads,
        remembering the previous grad state for `train` to restore."""
        def callback(parents, k, v, n):
            if isinstance(v, Module) and hasattr(v, "is_train"):
                v.is_train = False
        self.dfs([], None, callback, None)

        # backup stop grad or not
        if not hasattr(self, "backup_grad_state"):
            self.backup_grad_state = {}
        for p in self.parameters():
            if id(p) not in self.backup_grad_state:
                self.backup_grad_state[id(p)] = not p.is_stop_grad()
            p.stop_grad()

    def train(self):
        """Switch to training mode: set `is_train` flags and restore the
        grad state recorded by `eval`."""
        def callback(parents, k, v, n):
            if isinstance(v, Module) and hasattr(v, "is_train"):
                v.is_train = True
        self.dfs([], None, callback, None)

        # restore grads that were enabled before eval() stopped them
        if hasattr(self, "backup_grad_state"):
            for p in self.parameters():
                if id(p) in self.backup_grad_state and self.backup_grad_state[id(p)]:
                    p.start_grad()
|
|
|
|
def make_module(func, exec_n_args=1):
    """Wrap *func* into a Module class.

    Construction-time args/kwargs are stored and forwarded to *func*
    after the call-time args. (*exec_n_args* is kept for signature
    compatibility; it is not used here.)
    """
    class MakeModule(Module):
        def __init__(self, *args, **kw):
            # (dropped the old no-op `self.__doc__ == 'doc'` comparison,
            # which was likely an intended assignment that did nothing)
            self.args = args
            self.kw = kw

        def execute(self, *args):
            # call-time args first, then the construction-time args
            return func(*args, *self.args, **self.kw)

        def __str__(self):
            return 'str'

        def __repr__(self):
            return self.__str__()

        def extra_repr(self):
            return ''

    return MakeModule
|
|
|
|
|
|
def dirty_fix_pytorch_runtime_error():
    ''' This function should be called before pytorch.

    Example:
        import jittor as jt
        jt.dirty_fix_pytorch_runtime_error()
        import torch
    '''
    import os
    # NOTE(review): RTLD_DEEPBIND is glibc-specific; presumably this makes
    # later dlopens (torch) bind their own symbols instead of jittor's —
    # confirm behavior on non-Linux platforms.
    os.RTLD_GLOBAL = os.RTLD_GLOBAL | os.RTLD_DEEPBIND
|
|
|
|
|
|
import atexit
|
|
|
|
class ExitHooks(object):
    """Wraps sys.exit and sys.excepthook to remember how the process ended
    (used by jittor_exit to decide whether to sync pending work)."""

    def __init__(self):
        # exit code passed to sys.exit, if any
        self.exit_code = None
        # uncaught exception, if any
        self.exception = None

    def hook(self):
        # keep the original sys.exit so we can delegate to it
        self._orig_exit = sys.exit
        sys.exit = self.exit
        sys.excepthook = self.exc_handler

    def exit(self, code=0):
        # remember the code, then perform the real exit
        self.exit_code = code
        self._orig_exit(code)

    def exc_handler(self, exc_type, exc, *args):
        # remember the exception and still print the usual traceback
        self.exception = exc
        traceback.print_exception(exc_type, exc, *args)
|
|
|
|
# Install the exit/exception hooks at import time so jittor_exit can tell
# how the interpreter is going down.
hooks = ExitHooks()
hooks.hook()
|
|
|
|
def jittor_exit():
    """atexit handler: sync pending ops only on a clean interpreter exit."""
    if hooks.exit_code is not None or hooks.exception is not None:
        # explicit sys.exit or uncaught exception: skip the final sync
        pass
    else:
        core.sync_all(True)

atexit.register(jittor_exit)
|
|
|
|
# Printing a Var shows its numpy data; peek() gives a cheap dtype+shape string.
Var.__repr__ = Var.__str__ = lambda x: str(x.data)
Var.peek = lambda x: str(x.dtype)+str(x.shape)
|
|
|
|
from . import nn
|
|
from .nn import matmul
|
|
from . import contrib
|