LLM support with jtorch

Dun Liang 2023-03-20 22:33:12 +08:00
parent b25e62f1bb
commit 90850d36ad
13 changed files with 208 additions and 27 deletions

View File

@ -87,7 +87,7 @@ def safeunpickle(path):
from jittor_utils.misc import download_url_to_local
download_url_to_local(path, base, compiler.ck_path, None)
path = fname
if path.endswith(".pth"):
if path.endswith(".pth") or path.endswith(".pt") or path.endswith(".bin") :
from jittor_utils.load_pytorch import load_pytorch
model_dict = load_pytorch(path)
return model_dict
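A quick usage sketch (not part of the diff), assuming jt.load routes through safeunpickle: with the extension check above, PyTorch-style checkpoints saved as .pt or .bin should now be handed to load_pytorch just like .pth files. The file name below is hypothetical.
import jittor as jt
# hypothetical checkpoint file; previously only ".pth" took the load_pytorch path
state = jt.load("pytorch_model.bin")
print(type(state))   # expected: a plain dict of parameter arrays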
@ -371,7 +371,8 @@ def array(data, dtype=None):
dtype = str(dtype)
elif callable(dtype):
dtype = dtype.__name__
ret = ops.array(np.array(data, dtype))
with jt.flag_scope(auto_convert_64_to_32=0):
ret = ops.array(np.array(data, dtype))
else:
ret = ops.array(data)
# TODO: move those code to core
@ -460,6 +461,10 @@ def ones(*shape, dtype="float32"):
shape = shape[0]
return unary(1, dtype).broadcast(shape)
def new_ones(x, size):
return ones(size, x.dtype)
Var.new_ones = new_ones
def ones_like(x):
''' Constructs a jittor Var with all elements set to 1 and shape same with x.
@ -487,6 +492,22 @@ def zeros(*shape, dtype="float32"):
shape = shape[0]
return unary(0, dtype).broadcast(shape)
def new_zeros(x, size):
return zeros(size, x.dtype)
Var.new_zeros = new_zeros
def empty(*shape, dtype="float32"):
if isinstance(shape, tuple) and isinstance(shape[-1], (str, NanoString)):
dtype = shape[-1]
shape = shape[:-1]
if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
shape = shape[0]
return ops.empty(shape, dtype)
def new_empty(x, size):
return empty(size, x.dtype)
Var.new_empty = new_empty
def full(shape,val,dtype="float32"):
''' Constructs a jittor Var with all elements set to val.
@ -503,6 +524,10 @@ def full(shape,val,dtype="float32"):
shape = (shape,)
return unary(val, dtype).broadcast(shape)
def new_full(x, size, val):
return full(size, val, x.dtype)
Var.new_full = new_full
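A usage sketch (not part of the diff) for the torch-style new_* helpers added above; all of them reuse the source Var's dtype:
import jittor as jt
x = jt.zeros((2, 3), dtype="float16")
a = x.new_ones((4,))          # ones with x's dtype (float16)
b = x.new_zeros((2, 2))       # zeros with x's dtype
c = x.new_empty((3,))         # uninitialized, x's dtype
d = x.new_full((2, 2), 7.0)   # filled with 7.0, x's dtype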
def full_like(x, val, dtype=None) -> Var:
''' Constructs a jittor Var with all elements set to val and shape same with x.
@ -739,8 +764,6 @@ Var.type_as = type_as
Var.astype = Var.cast
def masked_fill(x, mask, value):
assert list(x.shape) == list(mask.shape)
# TODO: assert mask = 0 or 1
return x * (1 - mask) + mask * value
Var.masked_fill = masked_fill
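A small sketch of masked_fill as implemented above: the mask is expected to hold 0/1 values, positions where it is 1 take value, the rest keep x.
import jittor as jt
x = jt.array([1.0, 2.0, 3.0])
mask = jt.array([0, 1, 0])
y = x.masked_fill(mask, -1e9)   # -> [1.0, -1e9, 3.0]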
@ -1271,6 +1294,19 @@ class Module:
Loads the module's parameters from a dictionary.
'''
self.load_parameters(params)
def _load_from_state_dict(self, state, prefix="", *args, **kw):
if len(prefix):
new_state = {}
for k,v in state.items():
if k.startswith(prefix):
new_state[k[len(prefix):]] = v
state = new_state
self.load_state_dict(state)
def cuda(self):
flags.use_cuda = 1
return self
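A hedged sketch of the two Module helpers above: _load_from_state_dict strips a sub-module prefix before delegating to load_state_dict (the HuggingFace-style calling convention), and cuda() simply flips the global flag. The prefix below is hypothetical.
import jittor as jt
from jittor import nn
layer = nn.Linear(4, 4)
# build a prefixed state dict from the layer's own parameters, for illustration
full_state = {"encoder.layer0." + k: v for k, v in layer.state_dict().items()}
layer._load_from_state_dict(full_state, prefix="encoder.layer0.")
layer.cuda()   # sets jt.flags.use_cuda = 1; only meaningful on a CUDA build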
def modules(self) -> List:
''' Returns a list of sub-modules in the module recursively.
@ -1625,20 +1661,44 @@ Arguments of hook are defined as::
def __getattr__(self, key):
return object.__getattribute__(self, key)
def register_buffer(self, key, value):
object.__setattr__(self, key, value)
return value
def float64(self):
'''convert all parameters to float64'''
self._amp_level = 0
for p in self.parameters():
if p.dtype.is_float():
p.assign(p.float64())
return self
def float32(self):
'''convert all parameters to float32'''
self._amp_level = 0
for p in self.parameters():
if p.dtype.is_float():
p.assign(p.float32())
return self
def float16(self):
'''convert all parameters to float16'''
self._amp_level = 4
cls = self.__class__
cls.__call__ = cls.__half_call__
for p in self.parameters():
if p.dtype.is_float():
p.assign(p.float16())
return self
def __half_call__(self, *args, **kw):
amp_level = getattr(self, "_amp_level", -1)
if amp_level >= 0:
with flag_scope(amp_level=amp_level):
return self.execute(*args, **kw)
else:
return self.execute(*args, **kw)
def half(self):
'''convert all parameters to float16'''
return self.float16()
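A hedged sketch of the fp16 path added above: float16()/half() convert floating parameters in place and patch __call__ so that forward runs inside flag_scope(amp_level=4).
import jittor as jt
from jittor import nn
model = nn.Linear(8, 8).half()      # same as model.float16()
x = jt.random((2, 8)).float16()
y = model(x)                        # executed under amp_level=4 via __half_call__
print(y.dtype)                      # expected: float16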
@ -1646,6 +1706,7 @@ Arguments of hook are defined as::
def float_auto(self):
'''convert all parameters to float16 or float32 automatically
by jt.flags.auto_mixed_precision_level and jt.flags.amp_reg'''
self._amp_level = -1
for p in self.parameters():
if p.dtype.is_float():
p.assign(p.float_auto())
@ -1890,12 +1951,10 @@ Var.size = size
def to_int(v):
assert v.dtype.is_int()
return v.item()
return ori_int(v.item())
def to_float(v):
assert v.dtype.is_float()
return v.item()
return ori_float(v.item())
def to_bool(v):
assert v.dtype.is_int() or v.dtype.is_bool()
@ -1937,19 +1996,44 @@ from . import nn
from . import attention
from . import lr_scheduler
from . import linalg
from .linalg import einsum
from .nn import matmul, \
bmm, bmm_transpose
bmm, bmm_transpose, \
baddbmm
from . import contrib
from . import numpy2cupy
from .contrib import concat
from .contrib import concat, cat
from .misc import *
from . import sparse
from . import optim
from . import dataset
from . import init
dtype = NanoString
import jittor_utils
for backend in jittor_utils.backends:
if hasattr(backend, "post_process"):
backend.post_process()
# impl x.func(...) -> func_(...)
args = {"x", "input", "self"}
_white_list = {"mul", "add", "sub"}
for k,v in list(Var.__dict__.items()):
if k.startswith("_"): continue
if k.endswith("_"): continue
if not callable(v): continue
if k not in _white_list:
if not hasattr(v, "__code__"): continue
conames = v.__code__.co_varnames
if len(conames) == 0: continue
arg_name = conames[0]
if arg_name not in args: continue
new_k = k+"_"
if hasattr(Var, new_k): continue
def inplace_wrapper(new_k, prev_func):
setattr(Var, new_k, lambda x, *args, **kw: x.assign(prev_func(x, *args, **kw)))
inplace_wrapper(new_k, v)
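A brief sketch of what the loop above generates: for eligible Var methods whose first parameter is named x/input/self (plus the mul/add/sub white list), a trailing-underscore variant is added that assigns the result back into the variable.
import jittor as jt
x = jt.array([1.0, 2.0, 3.0])
x.add_(1.0)   # equivalent to x.assign(x.add(1.0)) -> [2.0, 3.0, 4.0]
x.mul_(2.0)   # equivalent to x.assign(x.mul(2.0)) -> [4.0, 6.0, 8.0]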

View File

@ -270,3 +270,5 @@ Example::
# s = jt.setitem(s, tuple(slices), a)
cdim += a.shape[dim]
return s
cat = concat

View File

@ -163,16 +163,20 @@ string process_acl(const string& src, const string& name, const map<string,strin
new_src = "#include <Python.h>\n#include <pystate.h>\n"+
replace(new_src, "op->do_run_after_prepare(jkl);",
R"({
auto state = _PyThreadState_UncheckedGet();
Py_BEGIN_ALLOW_THREADS
op->do_run_after_prepare(jkl);
if (!_PyThreadState_UncheckedGet()) {
PyEval_AcquireThread(state);
}
Py_END_ALLOW_THREADS
})");
}
if (name == "profiler.cc") {
new_src = token_replace_all(new_src, ".cc", ".tikcc");
}
// LOGir << name << (name == "pass_manager.cc");
if (name == "pass_manager.cc") {
LOGir << "replace" << name;
new_src = token_replace_all(new_src, "run_pass<FloatAtomicFixPass>();", "WTF");
}
// ????????
return new_src;
}
@ -201,13 +205,42 @@ void acl_jittor_op_compiler(string& filename, string& src, bool is_acl, string&
src = new_src;
new_src = token_replace_all(new_src, "atomicAdd(&$1,$2);", "$1=$1+$2;");
new_src = token_replace_all(new_src, "bool", "int8");
new_src = token_replace_all(new_src, "::numeric_min<float32>()", "-1e30");
new_src = token_replace_all(new_src, "::numeric_max<float32>()", "1e30");
// TODO: support max
// new_src = token_replace_all(new_src, "::max($1,$2);", "($1)>($2)?($1):($2);");
unordered_map<string,string> opmap = {
// {"::max","tikcc::scalar_max"},
{"::sqrtf", "tikcc::scalar_sqrt"}
};
auto ss = split(new_src, ";");
for (auto &s : ss) {
if (s.find("?") != string::npos) {
s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;");
}
if (s.find("::max") != string::npos) {
if (s.find("auto") == string::npos) {
s = token_replace_all(s+";", " $1=$4::max($2,$3);", " $1=$2;if ($2 < $3) $1=$3;");
} else {
s = token_replace_all(s+";", "auto $1=$4::max($2,$3);", "auto $1=$2;if ($2 < $3) $1=$3;");
}
}
for (auto& kv : opmap) {
if (s.find(kv.first) != string::npos) {
if (s.find("auto") == string::npos) {
// $1 = op($2) --> op($1, $2)
s = token_replace_all(s+";", " $1= "+kv.first+"($2);", kv.second+"($1, $2);");
} else {
// auto $1 = op($2) --> float32 $1; op($1, $2);
s = token_replace_all(s+";", "auto $1= "+kv.first+"($2);", "float32 $1; " + kv.second+"($1, $2);");
}
}
}
// s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;");
// s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;");
// if (s.find("::max") != string::npos) {
// s = token_replace_all(s+";", " $1= ::max($2);", "tikcc::scalar_max($1, $2);");
// }
}
new_src = join(ss, ";");
src = new_src;
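A rough before/after illustration (assumed, not taken from the commit) of what the per-statement rewrites above aim to produce for tikcc, written as plain strings:
# ternary rewrite: "auto $1=$2?$3:$4;" -> "auto $1=$3;if (!($2)) $1=$4;"
ternary_before = "auto y = cond ? a : b;"
ternary_after  = "auto y = a; if (!(cond)) y = b;"
# ::max rewrite: a comparison instead of the intrinsic
max_before = "auto m = ::max(a, b);"
max_after  = "auto m = a; if (a < b) m = b;"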

View File

@ -112,6 +112,7 @@ def constant_(var, value=0.0):
'''
return var.assign(constant(var.shape, var.dtype, value))
Var.constant_ = constant_
fill = Var.fill_ = constant_
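A short usage note (not part of the diff): the new aliases make Var.fill_ and init.fill point at constant_.
import jittor as jt
x = jt.zeros((2, 2))
x.fill_(3.0)   # in-place fill via init.constant_ -> every element becomes 3.0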
def zero(shape, dtype="float32"):
'''Generate zero Jittor Var.

View File

@ -438,6 +438,8 @@ def einsum(string, *args):
:return: return values depend on the input string kinds.
"""
import numpy as np_cpu
if string == "i,j->ij":
return args[0].broadcast((args[0].shape[0], args[1].shape[0]), dims=[1]).multiply(args[1])
def forward_code(np, data):
out = data["outputs"][0]
npout = np.einsum(string, *data["inputs"])
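A quick check (not part of the diff) of the "i,j->ij" fast path above, which is just the outer product computed with a broadcast and a multiply instead of the numpy fallback:
import jittor as jt
a = jt.array([1.0, 2.0, 3.0])
b = jt.array([10.0, 20.0])
out = jt.einsum("i,j->ij", a, b)   # shape (3, 2); out[i, j] == a[i] * b[j]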

View File

@ -128,6 +128,25 @@ def __iter__(x):
return result.__iter__()
jt.Var.__iter__ = __iter__
def __contains__(x, key):
return bool((x == key).any())
jt.Var.__contains__ = __contains__
def new(x, *args):
if len(args) != 1 or isinstance(args[0], int):
return jt.empty(args, x.dtype)
return jt.array(args[0]).cast(x.dtype)
jt.Var.new = new
def __index__(x):
return int(x.item())
jt.Var.__index__ = __index__
def sort(input, dim=-1, descending=False, stable=False):
index, value = jt.argsort(input, dim, descending)
return value, index
jt.Var.sort = sort
def all(x, dim=()):
return x.all_(dim).bool()
jt.Var.all = all
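A usage sketch (not part of the diff) of the Var conveniences added above:
import jittor as jt
x = jt.array([3, 1, 2])
print(2 in x)               # __contains__ -> True
values, indices = x.sort()  # torch-style return order: (values, indices)
y = x.new([7, 8, 9])        # data argument: new Var cast to x's dtype
z = x.new(2, 3)             # int arguments: uninitialized Var of shape (2, 3)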
@ -787,6 +806,9 @@ jt.Var.nonzero = nonzero
def arange(start=0, end=None, step=1,dtype=None):
if isinstance(start, jt.Var): start = start.item()
if isinstance(end, jt.Var): end = end.item()
if isinstance(step, jt.Var): step = step.item()
if end is None:
end,start = start,0
l = round((end-start)//step)+1
@ -2015,6 +2037,7 @@ def triu(input: jt.Var, diagonal:int=0) -> jt.Var:
mask = index[-2] <= index[-1] - diagonal
return input*mask
jt.Var.triu = triu
jt.Var.triu_ = lambda x: x.assign(x.triu())
def tril(input: jt.Var, diagonal:int=0) -> jt.Var:
''' Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0.
@ -2039,6 +2062,7 @@ def tril(input: jt.Var, diagonal:int=0) -> jt.Var:
mask = index[-2] >= index[-1] - diagonal
return input*mask
jt.Var.tril = tril
jt.Var.tril_ = lambda x: x.assign(x.tril())
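A quick sketch (not part of the diff) of the in-place triangular helpers registered above:
import jittor as jt
m = jt.ones((3, 3))
m.triu_()   # keeps the upper triangle (diagonal included), zeros the rest, in place
n = jt.ones((3, 3))
n.tril_()   # same idea for the lower triangle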
def all_equal(a: jt.Var, b: jt.Var) -> bool:
return (a == b).all().item()

View File

@ -73,6 +73,12 @@ Example::
assert len(a.shape) > 2 and len(b.shape) > 2
return matmul(a, b)
def baddbmm(input, batch1, batch2, beta=1, alpha=1):
res = bmm(batch1, batch2)
if alpha != 1: res = res * alpha
if beta == 0: return res
return beta * input + res
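A sketch of baddbmm as defined above, which follows the torch convention out = beta * input + alpha * (batch1 @ batch2) via a batched matmul:
import jittor as jt
input  = jt.random((4, 2, 5))
batch1 = jt.random((4, 2, 3))
batch2 = jt.random((4, 3, 5))
out = jt.baddbmm(input, batch1, batch2, beta=0.5, alpha=2.0)   # shape (4, 2, 5)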
def matmul(a, b):
''' matrix multiply,
@ -1683,15 +1689,18 @@ class Embedding(Module):
[ 0.14941819 0.57047683 -1.3217674]
[ 0.14941819 0.57047683 -1.3217674]], dtype=float32)
'''
def __init__(self, num, dim):
self.num = num
self.dim = dim
self.weight = jt.init.gauss([num,dim],'float32').stop_grad()
def __init__(self, num_embeddings, embedding_dim, dtype="float32"):
self.num = num_embeddings
self.dim = embedding_dim
self.weight = jt.init.gauss([self.num, self.dim], dtype).stop_grad()
def execute(self, x):
res = self.weight[x.flatten()].reshape(x.shape + [self.dim])
res = self.weight[x]
return res
def embedding(input, weight):
return weight[input]
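A usage sketch (not part of the diff) for the reworked Embedding: integer indices of any shape now gather rows directly from the weight matrix, and a functional form is exposed as well.
import jittor as jt
from jittor import nn
emb = nn.Embedding(10, 4)            # num_embeddings=10, embedding_dim=4
ids = jt.array([[1, 2], [3, 4]])
out = emb(ids)                       # expected shape: (2, 2, 4)
vec = nn.embedding(ids, emb.weight)  # same lookup via the functional form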
class PixelShuffle(Module):
def __init__(self, upscale_factor):
self.upscale_factor = upscale_factor
@ -2973,3 +2982,19 @@ class KLDivLoss(Module):
else:
loss = loss_pointwise
return loss
class Mish(Module):
def __init__(self, inplace=False):
'''
Applies the Mish function, element-wise.
reference: Mish - A Self Regularized Non-Monotonic Neural Activation Function.
'''
pass
def execute(self, x):
return x * jt.tanh(jt.softplus(x))
def mish(x, inplace=False):
return x * jt.tanh(jt.softplus(x))
def skip_init(module_cls, *args, **kw):
return module_cls(*args, **kw)
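A short sketch (not part of the diff) of the Mish addition, which computes x * tanh(softplus(x)); the inplace argument is accepted for torch compatibility but ignored, and skip_init simply constructs the module normally:
import jittor as jt
from jittor import nn
x = jt.array([-1.0, 0.0, 1.0])
y1 = nn.Mish()(x)
y2 = nn.mish(x)                       # functional form, same result
lin = nn.skip_init(nn.Linear, 4, 4)   # equivalent to nn.Linear(4, 4) here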

View File

@ -778,7 +778,7 @@ def compile_src(src, h, basename):
if submodule_info and "attrs" in submodule_info and "core_name" in submodule_info["attrs"]:
core_name = submodule_info["attrs"]["core_name"]
has_map = class_name in ["VarHolder", "NanoVector"]
has_seq = class_name == "NanoVector"
has_seq = class_name in ["VarHolder", "NanoVector"]
# add extra include to avoid compile error
src_code = ""
if include_name.endswith("var_slices.h"):

View File

@ -91,7 +91,7 @@ unordered_set<string> binary_ops = {
* [in] y: the second input, a python number or jt.Var.
*/
// @pybind(subtract, __sub__)
// @pybind(subtract, __sub__, sub)
"subtract",
/**
@ -106,7 +106,7 @@ unordered_set<string> binary_ops = {
* [in] y: the second input, a python number or jt.Var.
*/
// @pybind(multiply, __mul__)
// @pybind(multiply, __mul__, mul)
"multiply",
/**
@ -138,7 +138,7 @@ unordered_set<string> binary_ops = {
returns float value even if the dtype of input Vars are both integers.
@see jt.ops.floor_divide() for floor division.
*/
// @pybind(divide, __truediv__)
// @pybind(divide, __truediv__, div)
"divide",
/**
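A hedged usage note (assuming the extra @pybind names above are exposed the same way as the existing aliases): the torch-style short names become callable on Vars.
import jittor as jt
a = jt.array([4.0, 9.0])
b = jt.array([2.0, 3.0])
print(a.sub(b), a.mul(b), a.div(b))   # same results as a - b, a * b, a / b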

View File

@ -17,6 +17,7 @@
#include "ops/op_register.h"
#include "ops/getitem_op.h"
#include "ops/setitem_op.h"
#include "type/fp16_compute.h"
namespace jittor {
@ -184,6 +185,13 @@ ArrayArgs VarHolder::fetch_sync() {
return {var->mem_ptr, var->shape, var->dtype()};
}
inline static void cast_item_data(ItemData& data) {
auto* fp16 = (float16*)&data;
auto* fp32 = (float32*)&data;
fp32[0] = float32(fp16[0]);
data.dtype = ns_float32;
}
ItemData VarHolder::item() {
sync();
CHECK(var->num==1) << "Item var size should be 1, but got" << var->num;
@ -199,6 +207,8 @@ ItemData VarHolder::item() {
{
std::memcpy(&data.data, var->mem_ptr, dsize);
}
if (data.dtype == ns_float16)
cast_item_data(data);
return data;
}
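A brief sketch (not part of the diff) of the float16 item() fix above: a one-element half-precision Var is widened to float32 before being handed back to Python.
import jittor as jt
v = jt.array([0.5]).float16()
print(v.item())   # a plain Python float; no half-precision value leaks out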

View File

@ -260,7 +260,7 @@ struct VarHolder {
/**
* return the number of dimensions.
*/
// @pyjt(__get__ndim)
// @pyjt(__get__ndim, dim)
inline int ndim() {
return var->shape.size();
}

View File

@ -452,7 +452,7 @@ def to_tensor(pic):
return np.float32(img) * np.float32(1/255.0)
else:
return img
pil_to_tensor = to_tensor
def _to_jittor_array(pic):

View File

@ -211,7 +211,7 @@ def load_pytorch(fn_name):
global contents, deserialized_objects, loaded_storages
loaded_storages = {}
deserialized_objects = {}
if not fn_name.endswith(".pth"):
if not (fn_name.endswith(".pth") or fn_name.endswith(".pt") or fn_name.endswith(".bin")):
print("This function is designed to load pytorch pth format files.")
return None
else: