mirror of https://github.com/Jittor/Jittor

Commit 90850d36ad: LLM support with jtorch (parent b25e62f1bb)
@@ -87,7 +87,7 @@ def safeunpickle(path):
        from jittor_utils.misc import download_url_to_local
        download_url_to_local(path, base, compiler.ck_path, None)
        path = fname
-    if path.endswith(".pth"):
+    if path.endswith(".pth") or path.endswith(".pt") or path.endswith(".bin") :
        from jittor_utils.load_pytorch import load_pytorch
        model_dict = load_pytorch(path)
        return model_dict
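A minimal usage sketch of the loader change above, assuming a hypothetical HuggingFace-style checkpoint file whose keys match the module's parameter names; jt.load now routes .pth/.pt/.bin files through load_pytorch and hands back a plain state dict::

    import jittor as jt
    from jittor import nn

    model = nn.Linear(768, 768)
    # "pytorch_model.bin" is a placeholder path for a PyTorch checkpoint
    state_dict = jt.load("pytorch_model.bin")
    model.load_state_dict(state_dict)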
@@ -371,7 +371,8 @@ def array(data, dtype=None):
            dtype = str(dtype)
        elif callable(dtype):
            dtype = dtype.__name__
-        ret = ops.array(np.array(data, dtype))
+        with jt.flag_scope(auto_convert_64_to_32=0):
+            ret = ops.array(np.array(data, dtype))
    else:
        ret = ops.array(data)
    # TODO: move those code to core
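The flag_scope wrapper above presumably keeps an explicitly requested 64-bit dtype from being auto-narrowed while the numpy buffer is converted; a small sketch under that assumption::

    import numpy as np
    import jittor as jt

    a = jt.array(np.ones(3, dtype=np.float64), dtype="float64")
    print(a.dtype)   # expected to stay float64 rather than float32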
@@ -460,6 +461,10 @@ def ones(*shape, dtype="float32"):
        shape = shape[0]
    return unary(1, dtype).broadcast(shape)

+def new_ones(x, size):
+    return ones(size, x.dtype)
+Var.new_ones = new_ones
+
def ones_like(x):
    ''' Constructs a jittor Var with all elements set to 1 and shape same with x.
@@ -487,6 +492,22 @@ def zeros(*shape, dtype="float32"):
        shape = shape[0]
    return unary(0, dtype).broadcast(shape)

+def new_zeros(x, size):
+    return zeros(size, x.dtype)
+Var.new_zeros = new_zeros
+
+def empty(*shape, dtype="float32"):
+    if isinstance(shape, tuple) and isinstance(shape[-1], (str, NanoString)):
+        dtype = shape[-1]
+        shape = shape[:-1]
+    if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
+        shape = shape[0]
+    return ops.empty(shape, dtype)
+
+def new_empty(x, size):
+    return empty(size, x.dtype)
+Var.new_empty = new_empty
+
def full(shape,val,dtype="float32"):
    ''' Constructs a jittor Var with all elements set to val.
@@ -503,6 +524,10 @@ def full(shape,val,dtype="float32"):
        shape = (shape,)
    return unary(val, dtype).broadcast(shape)

+def new_full(x, size, val):
+    return full(size, val, x.dtype)
+Var.new_full = new_full
+
def full_like(x, val, dtype=None) -> Var:
    ''' Constructs a jittor Var with all elements set to val and shape same with x.
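A short sketch of the torch-style allocator helpers added above (new_ones, new_zeros, new_empty, new_full); each new Var inherits its dtype from the source Var::

    import jittor as jt

    x = jt.rand(2, 3)                 # float32 source Var
    a = x.new_ones((4, 4))            # float32 ones
    b = x.new_zeros((4, 4))           # float32 zeros
    c = x.new_empty((4, 4))           # float32, contents uninitialized
    d = x.new_full((4, 4), 3.14)      # float32 filled with 3.14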
@@ -739,8 +764,6 @@ Var.type_as = type_as
Var.astype = Var.cast

def masked_fill(x, mask, value):
    assert list(x.shape) == list(mask.shape)
    # TODO: assert mask = 0 or 1
    return x * (1 - mask) + mask * value
Var.masked_fill = masked_fill
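The arithmetic form of masked_fill above can be checked by hand with a 0/1 mask::

    import jittor as jt

    x = jt.array([1.0, 2.0, 3.0])
    mask = jt.array([0.0, 1.0, 0.0])
    y = x.masked_fill(mask, -1e9)   # x*(1-mask) + mask*value -> [1.0, -1e9, 3.0]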
@@ -1271,6 +1294,19 @@ class Module:
        Loads the module's parameters from a dictionary.
        '''
        self.load_parameters(params)

+    def _load_from_state_dict(self, state, prefix="", *args, **kw):
+        if len(prefix):
+            new_state = {}
+            for k,v in state.items():
+                if k.startswith(prefix):
+                    new_state[k[len(prefix):]] = v
+            state = new_state
+        self.load_state_dict(state)
+
+    def cuda(self):
+        flags.use_cuda = 1
+        return self
+
    def modules(self) -> List:
        ''' Returns a list of sub-modules in the module recursively.
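A hedged sketch of the prefix handling in _load_from_state_dict above, using a hypothetical "encoder." prefix; matching keys are stripped before the usual load_state_dict path runs::

    import jittor as jt
    from jittor import nn

    layer = nn.Linear(8, 8)
    state = {"encoder.weight": jt.zeros(8, 8), "encoder.bias": jt.zeros(8)}
    layer._load_from_state_dict(state, prefix="encoder.")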
@@ -1625,20 +1661,44 @@ Arguments of hook are defined as::
    def __getattr__(self, key):
        return object.__getattribute__(self, key)

    def register_buffer(self, key, value):
        object.__setattr__(self, key, value)
        return value

    def float64(self):
        '''convert all parameters to float64'''
        self._amp_level = 0
        for p in self.parameters():
            if p.dtype.is_float():
                p.assign(p.float64())
        return self

    def float32(self):
        '''convert all parameters to float32'''
        self._amp_level = 0
        for p in self.parameters():
            if p.dtype.is_float():
                p.assign(p.float32())
        return self

    def float16(self):
        '''convert all parameters to float16'''
        self._amp_level = 4
        cls = self.__class__
        cls.__call__ = cls.__half_call__
        for p in self.parameters():
            if p.dtype.is_float():
                p.assign(p.float16())
        return self

    def __half_call__(self, *args, **kw):
        amp_level = getattr(self, "_amp_level", -1)
        if amp_level >= 0:
            with flag_scope(amp_level=amp_level):
                return self.execute(*args, **kw)
        else:
            return self.execute(*args, **kw)

    def half(self):
        '''convert all parameters to float16'''
        return self.float16()
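A minimal sketch of the half-precision path added above, assuming a CUDA build since float16 compute on CPU may be limited; float16() casts the floating parameters and routes later calls through __half_call__ under the stored amp_level::

    import jittor as jt
    from jittor import nn

    jt.flags.use_cuda = jt.has_cuda
    m = nn.Linear(16, 16).half()          # parameters become float16
    y = m(jt.rand(4, 16).float16())       # executed inside flag_scope(amp_level=4)
    print(m.weight.dtype, y.dtype)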
@@ -1646,6 +1706,7 @@ Arguments of hook are defined as::
    def float_auto(self):
        '''convert all parameters to float16 or float32 automatically
        by jt.flags.auto_mixed_precision_level and jt.flags.amp_reg'''
+        self._amp_level = -1
        for p in self.parameters():
            if p.dtype.is_float():
                p.assign(p.float_auto())
@@ -1890,12 +1951,10 @@ Var.size = size


def to_int(v):
    assert v.dtype.is_int()
-    return v.item()
+    return ori_int(v.item())

def to_float(v):
    assert v.dtype.is_float()
-    return v.item()
+    return ori_float(v.item())

def to_bool(v):
    assert v.dtype.is_int() or v.dtype.is_bool()
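The ori_int / ori_float change above is about handing back genuine Python scalars (ori_int and ori_float are presumed to be the builtins saved before the module defines its own int/float dtype aliases); assuming Var.__int__ and Var.__float__ are bound to these helpers elsewhere in the file::

    import jittor as jt

    n = int(jt.array(5))        # a plain Python int, usable e.g. as a list index
    f = float(jt.array(2.5))
    assert type(n) is int and type(f) is float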
@@ -1937,19 +1996,44 @@ from . import nn
from . import attention
from . import lr_scheduler
from . import linalg
from .linalg import einsum
from .nn import matmul, \
-    bmm, bmm_transpose
+    bmm, bmm_transpose, \
+    baddbmm
from . import contrib
from . import numpy2cupy
-from .contrib import concat
+from .contrib import concat, cat
from .misc import *
from . import sparse
from . import optim
from . import dataset
from . import init

dtype = NanoString

import jittor_utils

+for backend in jittor_utils.backends:
+    if hasattr(backend, "post_process"):
+        backend.post_process()
+
+# impl x.func(...) -> func_(...)
+args = {"x", "input", "self"}
+_white_list = {"mul", "add", "sub"}
+for k,v in list(Var.__dict__.items()):
+    if k.startswith("_"): continue
+    if k.endswith("_"): continue
+    if not callable(v): continue
+
+    if k not in _white_list:
+        if not hasattr(v, "__code__"): continue
+        conames = v.__code__.co_varnames
+        if len(conames) == 0: continue
+        arg_name = conames[0]
+        if arg_name not in args: continue
+
+    new_k = k+"_"
+    if hasattr(Var, new_k): continue
+    def inplace_wrapper(new_k, prev_func):
+        setattr(Var, new_k, lambda x, *args, **kw: x.assign(prev_func(x, *args, **kw)))
+    inplace_wrapper(new_k, v)
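A hedged illustration of the generated in-place variants above: for the white list, and for Python-level Var methods whose first parameter is named x, input or self, a trailing-underscore method is synthesized that assigns the result back into the variable (unless one already exists)::

    import jittor as jt

    v = jt.array([1.0, 2.0, 3.0])
    v.add_(10)                  # in-place wrapper generated around Var.add
    print(v)                    # values become [11. 12. 13.]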
@@ -270,3 +270,5 @@ Example::
        # s = jt.setitem(s, tuple(slices), a)
        cdim += a.shape[dim]
    return s
+
+cat = concat
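With the alias above (re-exported as jt.cat in jittor/__init__.py), the torch-style spelling works alongside concat; a one-line check::

    import jittor as jt

    ab = jt.cat([jt.array([1, 2]), jt.array([3, 4])], dim=0)   # [1 2 3 4]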
@@ -163,16 +163,20 @@ string process_acl(const string& src, const string& name, const map<string,strin
        new_src = "#include <Python.h>\n#include <pystate.h>\n"+
            replace(new_src, "op->do_run_after_prepare(jkl);",
            R"({
                auto state = _PyThreadState_UncheckedGet();
                Py_BEGIN_ALLOW_THREADS
                op->do_run_after_prepare(jkl);
                if (!_PyThreadState_UncheckedGet()) {
                    PyEval_AcquireThread(state);
                }
                Py_END_ALLOW_THREADS
            })");
    }
    if (name == "profiler.cc") {
        new_src = token_replace_all(new_src, ".cc", ".tikcc");
    }
    // LOGir << name << (name == "pass_manager.cc");
    if (name == "pass_manager.cc") {
        LOGir << "replace" << name;
        new_src = token_replace_all(new_src, "run_pass<FloatAtomicFixPass>();", "WTF");
    }
    // ????????
    return new_src;
}
@@ -201,13 +205,42 @@ void acl_jittor_op_compiler(string& filename, string& src, bool is_acl, string&
    src = new_src;

    new_src = token_replace_all(new_src, "atomicAdd(&$1,$2);", "$1=$1+$2;");
    new_src = token_replace_all(new_src, "bool", "int8");
    new_src = token_replace_all(new_src, "::numeric_min<float32>()", "-1e30");
    new_src = token_replace_all(new_src, "::numeric_max<float32>()", "1e30");
    // TODO: support max
    // new_src = token_replace_all(new_src, "::max($1,$2);", "($1)>($2)?($1):($2);");
    unordered_map<string,string> opmap = {
        // {"::max","tikcc::scalar_max"},
        {"::sqrtf", "tikcc::scalar_sqrt"}
    };
    auto ss = split(new_src, ";");
    for (auto &s : ss) {
        if (s.find("?") != string::npos) {
            s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;");
        }
        if (s.find("::max") != string::npos) {
            if (s.find("auto") == string::npos) {
                s = token_replace_all(s+";", " $1=$4::max($2,$3);", " $1=$2;if ($2 < $3) $1=$3;");
            } else {
                s = token_replace_all(s+";", "auto $1=$4::max($2,$3);", "auto $1=$2;if ($2 < $3) $1=$3;");
            }
        }
        for (auto& kv : opmap) {
            if (s.find(kv.first) != string::npos) {
                if (s.find("auto") == string::npos) {
                    // $1 = op($2) --> op($1, $2)
                    s = token_replace_all(s+";", " $1= "+kv.first+"($2);", kv.second+"($1, $2);");
                } else {
                    // auto $1 = op($2) --> float32 $1; op($1, $2);
                    s = token_replace_all(s+";", "auto $1= "+kv.first+"($2);", "float32 $1; " + kv.second+"($1, $2);");
                }
            }
        }
        // s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;");
        // s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;");
        // if (s.find("::max") != string::npos) {
        //     s = token_replace_all(s+";", " $1= ::max($2);", "tikcc::scalar_max($1, $2);");
        // }
    }
    new_src = join(ss, ";");
    src = new_src;
@@ -112,6 +112,7 @@ def constant_(var, value=0.0):
    '''
    return var.assign(constant(var.shape, var.dtype, value))
Var.constant_ = constant_
+fill = Var.fill_ = constant_

def zero(shape, dtype="float32"):
    '''Generate zero Jittor Var.
@@ -438,6 +438,8 @@ def einsum(string, *args):
    :return: return values depend on the input string kinds.
    """
    import numpy as np_cpu
+    if string == "i,j->ij":
+        return args[0].broadcast((args[0].shape[0], args[1].shape[0]), dims=[1]).multiply(args[1])
    def forward_code(np, data):
        out = data["outputs"][0]
        npout = np.einsum(string, *data["inputs"])
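The special case added above short-circuits the common outer product instead of going through the numpy fallback; a quick equivalence sketch::

    import jittor as jt

    x = jt.array([1.0, 2.0, 3.0])
    y = jt.array([10.0, 20.0])
    out = jt.linalg.einsum("i,j->ij", x, y)   # outer product, shape [3, 2]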
@@ -128,6 +128,25 @@ def __iter__(x):
    return result.__iter__()
jt.Var.__iter__ = __iter__

+def __contains__(x, key):
+    return bool((x == key).any())
+jt.Var.__contains__ = __contains__
+
+def new(x, *args):
+    if len(args) != 1 or isinstance(args[0], int):
+        return jt.empty(args, x.dtype)
+    return jt.array(args[0]).cast(x.dtype)
+jt.Var.new = new
+
+def __index__(x):
+    return int(x.item())
+jt.Var.__index__ = __index__
+
+def sort(input, dim=-1, descending=False, stable=False):
+    index, value = jt.argsort(input, dim, descending)
+    return value, index
+jt.Var.sort = sort
+
def all(x, dim=()):
    return x.all_(dim).bool()
jt.Var.all = all
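A short sketch of the torch-style Var helpers added above: containment tests, 0-d Vars as Python indices, torch's (values, indices) ordering for sort, and new() for allocation or data copy::

    import jittor as jt

    v = jt.array([3, 1, 2])
    assert 2 in v                     # __contains__
    print([10, 20, 30][jt.array(1)])  # __index__ lets a 0-d Var index a list -> 20
    values, indices = v.sort()        # values first, as in torch
    w = v.new(2, 3)                   # uninitialized int32 Var of shape [2, 3]
    u = v.new([7, 8])                 # data copied and cast to v's dtype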
@@ -787,6 +806,9 @@ jt.Var.nonzero = nonzero


def arange(start=0, end=None, step=1,dtype=None):
+    if isinstance(start, jt.Var): start = start.item()
+    if isinstance(end, jt.Var): end = end.item()
+    if isinstance(step, jt.Var): step = step.item()
    if end is None:
        end,start = start,0
    l = round((end-start)//step)+1
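With the .item() coercions above, scalar Vars can be passed straight to arange; a tiny sketch::

    import jittor as jt

    n = jt.array(5)
    r = jt.arange(n)        # same as jt.arange(5)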
@@ -2015,6 +2037,7 @@ def triu(input: jt.Var, diagonal:int=0) -> jt.Var:
    mask = index[-2] <= index[-1] - diagonal
    return input*mask
jt.Var.triu = triu
+jt.Var.triu_ = lambda x: x.assign(x.triu())

def tril(input: jt.Var, diagonal:int=0) -> jt.Var:
    ''' Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0.
@@ -2039,6 +2062,7 @@ def tril(input: jt.Var, diagonal:int=0) -> jt.Var:
    mask = index[-2] >= index[-1] - diagonal
    return input*mask
jt.Var.tril = tril
+jt.Var.tril_ = lambda x: x.assign(x.tril())

def all_equal(a: jt.Var, b: jt.Var) -> bool:
    return (a == b).all().item()
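The in-place variants added above simply assign the masked result back into the same Var::

    import jittor as jt

    m = jt.rand(4, 4)
    m.tril_()               # m now holds only its lower-triangular part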
@@ -73,6 +73,12 @@ Example::
    assert len(a.shape) > 2 and len(b.shape) > 2
    return matmul(a, b)

+def baddbmm(input, batch1, batch2, beta=1, alpha=1):
+    res = bmm(batch1, batch2)
+    if alpha != 1: res = res * alpha
+    if beta == 0: return res
+    return beta * input + res
+
def matmul(a, b):
    ''' matrix multiply,
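baddbmm above follows the torch semantics out = beta * input + alpha * (batch1 @ batch2); a quick shape check::

    import jittor as jt
    from jittor import nn

    input  = jt.rand(8, 3, 5)
    batch1 = jt.rand(8, 3, 4)
    batch2 = jt.rand(8, 4, 5)
    out = nn.baddbmm(input, batch1, batch2, beta=0.5, alpha=2.0)
    print(out.shape)        # expected [8, 3, 5]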
@@ -1683,15 +1689,18 @@ class Embedding(Module):
    [ 0.14941819 0.57047683 -1.3217674]
    [ 0.14941819 0.57047683 -1.3217674]], dtype=float32)
    '''
-    def __init__(self, num, dim):
-        self.num = num
-        self.dim = dim
-        self.weight = jt.init.gauss([num,dim],'float32').stop_grad()
+    def __init__(self, num_embeddings, embedding_dim, dtype="float32"):
+        self.num = num_embeddings
+        self.dim = embedding_dim
+        self.weight = jt.init.gauss([self.num, self.dim], dtype).stop_grad()

    def execute(self, x):
-        res = self.weight[x.flatten()].reshape(x.shape + [self.dim])
+        res = self.weight[x]
        return res

+def embedding(input, weight):
+    return weight[input]
+
class PixelShuffle(Module):
    def __init__(self, upscale_factor):
        self.upscale_factor = upscale_factor
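A hedged sketch of the renamed Embedding constructor and the functional helper above; indexing the weight with the raw id tensor keeps the input shape and appends the embedding dimension::

    import jittor as jt
    from jittor import nn

    emb = nn.Embedding(num_embeddings=1000, embedding_dim=64)
    ids = jt.array([[1, 2, 3], [4, 5, 6]])     # shape [2, 3]
    out = emb(ids)                             # shape [2, 3, 64]
    same = nn.embedding(ids, emb.weight)       # functional form, same values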
@@ -2973,3 +2982,19 @@ class KLDivLoss(Module):
    else:
        loss = loss_pointwise
    return loss
+
+class Mish(Module):
+    def __init__(self, inplace=False):
+        '''
+        Applies the Mish function, element-wise.
+        reference: Mish - A Self Regularized Non-Monotonic Neural Activation Function.
+        '''
+        pass
+    def execute(self, x):
+        return x * jt.tanh(jt.softplus(x))
+
+def mish(x, inplace=False):
+    return x * jt.tanh(jt.softplus(x))
+
+def skip_init(module_cls, *args, **kw):
+    return module_cls(*args, **kw)
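Mish above is the usual x * tanh(softplus(x)); the module and the functional form are interchangeable::

    import jittor as jt
    from jittor import nn

    x = jt.array([-1.0, 0.0, 1.0])
    y1 = nn.Mish()(x)
    y2 = nn.mish(x)         # same values as y1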
@@ -778,7 +778,7 @@ def compile_src(src, h, basename):
    if submodule_info and "attrs" in submodule_info and "core_name" in submodule_info["attrs"]:
        core_name = submodule_info["attrs"]["core_name"]
    has_map = class_name in ["VarHolder", "NanoVector"]
-    has_seq = class_name == "NanoVector"
+    has_seq = class_name in ["VarHolder", "NanoVector"]
    # add extra include to avoid compile error
    src_code = ""
    if include_name.endswith("var_slices.h"):
@@ -91,7 +91,7 @@ unordered_set<string> binary_ops = {
     * [in] y: the second input, a python number or jt.Var.

     */
-    // @pybind(subtract, __sub__)
+    // @pybind(subtract, __sub__, sub)
    "subtract",

    /**
@@ -106,7 +106,7 @@ unordered_set<string> binary_ops = {
     * [in] y: the second input, a python number or jt.Var.

     */
-    // @pybind(multiply, __mul__)
+    // @pybind(multiply, __mul__, mul)
    "multiply",

    /**
@@ -138,7 +138,7 @@ unordered_set<string> binary_ops = {
     returns float value even if the dtype of input Vars are both integers.
     @see jt.ops.floor_divide() for floor division.
     */
-    // @pybind(divide, __truediv__)
+    // @pybind(divide, __truediv__, div)
    "divide",

    /**
@@ -17,6 +17,7 @@
#include "ops/op_register.h"
#include "ops/getitem_op.h"
#include "ops/setitem_op.h"
+#include "type/fp16_compute.h"

namespace jittor {
@@ -184,6 +185,13 @@ ArrayArgs VarHolder::fetch_sync() {
    return {var->mem_ptr, var->shape, var->dtype()};
}

+inline static void cast_item_data(ItemData& data) {
+    auto* fp16 = (float16*)&data;
+    auto* fp32 = (float32*)&data;
+    fp32[0] = float32(fp16[0]);
+    data.dtype = ns_float32;
+}
+
ItemData VarHolder::item() {
    sync();
    CHECK(var->num==1) << "Item var size should be 1, but got" << var->num;
@@ -199,6 +207,8 @@ ItemData VarHolder::item() {
    {
        std::memcpy(&data.data, var->mem_ptr, dsize);
    }
+    if (data.dtype == ns_float16)
+        cast_item_data(data);
    return data;
}
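The C++ change above up-casts a half-precision scalar to float32 before item() returns it, so Python always sees an ordinary float; a hedged sketch, assuming float16 casts are available in the current build::

    import jittor as jt

    jt.flags.use_cuda = jt.has_cuda
    v = jt.array([1.5]).float16()
    print(v.item(), type(v.item()))   # 1.5 <class 'float'>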
@@ -260,7 +260,7 @@ struct VarHolder {
    /**
     * return the number of dimensions.
     */
-    // @pyjt(__get__ndim)
+    // @pyjt(__get__ndim, dim)
    inline int ndim() {
        return var->shape.size();
    }
@@ -452,7 +452,7 @@ def to_tensor(pic):
        return np.float32(img) * np.float32(1/255.0)
    else:
        return img

+pil_to_tensor = to_tensor

def _to_jittor_array(pic):
@@ -211,7 +211,7 @@ def load_pytorch(fn_name):
    global contents, deserialized_objects, loaded_storages
    loaded_storages = {}
    deserialized_objects = {}
-    if not fn_name.endswith(".pth"):
+    if not (fn_name.endswith(".pth") or fn_name.endswith(".pt") or fn_name.endswith(".bin")):
        print("This function is designed to load pytorch pth format files.")
        return None
    else: