chatglm optimize v2

This commit is contained in:
Dun Liang 2023-03-31 13:07:47 +08:00
parent 1e00883ed1
commit faa5386d82
19 changed files with 9550 additions and 40 deletions

View File

@ -1335,7 +1335,7 @@ class Module:
state = new_state
self.load_state_dict(state)
def cuda(self):
def cuda(self, device=None):
flags.use_cuda = 1
return self
@ -1585,10 +1585,10 @@ Arguments of hook are defined as::
if param.shape == v.shape:
LOG.v(f'load parameter {key} success ...')
v.update(param)
v.sync(False, False)
else:
n_failed += 1
LOG.e(f'load parameter {key} failed: expect the shape of {key} to be {v.shape}, but got {param.shape}')
jt.sync_all()
if n_failed:
LOG.w(f"load total {len(params)} params, {n_failed} failed")
@ -1680,6 +1680,16 @@ Arguments of hook are defined as::
if not hasattr(self, "is_train"):
self.is_train = True
return self.is_train
@property
def training(self):
    """Return the ``is_train`` flag, creating it as ``True`` on first access.

    This property reads and writes the same ``is_train`` attribute as the
    setter below, so both spellings stay in sync.
    """
    try:
        return self.is_train
    except AttributeError:
        # First access on an object that never set the flag: default to True.
        self.is_train = True
        return self.is_train
@training.setter
def training(self, value):
    """Store ``value`` in the underlying ``is_train`` attribute."""
    self.is_train = value
def mpi_param_broadcast(self, root=0):
if not in_mpi: return

View File

@ -2125,7 +2125,7 @@ def multinomial(weights: jt.Var, num_samples: int, replacement: bool=False) -> j
# A-Res algorithm
# Pavlos S. Efraimidis and Paul G. Spirakis, 2006, Weighted random sampling with a reservoir
assert num_samples <= weights.shape[-1], "num_samples larger than the input"
rand = jt.rand(weights.shape) ** (1/weights)
rand = jt.rand(weights.shape) ** ((1/weights).safe_clip())
_, indices = jt.topk(rand.safe_clip(), num_samples)
return indices

7801
python/jittor/src/misc/miniz.cc Executable file

File diff suppressed because it is too large Load Diff

1376
python/jittor/src/misc/miniz.h Executable file

File diff suppressed because it is too large Load Diff

View File

@ -17,7 +17,11 @@ namespace jittor {
// Checks one float16 element per thread and calls __trap() on a bad value.
// A NaN is always rejected; an infinity is rejected too unless the build
// sets JT_CHECK_NAN to 2.
__global__ void _check_nan_float16(__half* __restrict__ ptr, int64 num) {
    const int64 idx = threadIdx.x + blockIdx.x * (int64)blockDim.x;
    // Threads past the end of the buffer have nothing to check.
    if (idx >= num) return;
    bool invalid = isnan(__half2float(ptr[idx]));
    #if JT_CHECK_NAN != 2
    invalid = invalid || __hisinf(ptr[idx]);
    #endif
    if (invalid)
        __trap();
}
@ -25,7 +29,11 @@ __global__ void _check_nan_float16(__half* __restrict__ ptr, int64 num) {
// Checks one float32 element per thread and calls __trap() on a bad value.
// A NaN is always rejected; an infinity is rejected too unless the build
// sets JT_CHECK_NAN to 2.
__global__ void _check_nan_float32(float32* __restrict__ ptr, int64 num) {
    const int64 idx = threadIdx.x + blockIdx.x * (int64)blockDim.x;
    // Threads past the end of the buffer have nothing to check.
    if (idx >= num) return;
    bool invalid = ::isnan(ptr[idx]);
    #if JT_CHECK_NAN != 2
    invalid = invalid || ::isinf(ptr[idx]);
    #endif
    if (invalid)
        __trap();
}
@ -34,7 +42,11 @@ __global__ void _check_nan_float32(float32* __restrict__ ptr, int64 num) {
// Checks one float64 element per thread and calls __trap() on a bad value.
// A NaN is always rejected; an infinity is rejected too unless the build
// sets JT_CHECK_NAN to 2.
__global__ void _check_nan_float64(float64* __restrict__ ptr, int64 num) {
    const int64 idx = threadIdx.x + blockIdx.x * (int64)blockDim.x;
    // Threads past the end of the buffer have nothing to check.
    if (idx >= num) return;
    bool invalid = ::isnan(ptr[idx]);
    #if JT_CHECK_NAN != 2
    invalid = invalid || ::isinf(ptr[idx]);
    #endif
    if (invalid)
        __trap();
}

View File

@ -195,9 +195,11 @@ constexpr int amp_keep_reduce = 4;
constexpr int amp_keep_white = 8;
constexpr int amp_array_prefer = 16;
inline NanoString float_dtype(int dsize_) {
if (amp_reg & amp_prefer32) return ns_float32;
if (amp_reg & amp_prefer16) return ns_float16;
inline NanoString float_dtype(int dsize_, bool has_scalar=false) {
if (!has_scalar) {
if (amp_reg & amp_prefer32) return ns_float32;
if (amp_reg & amp_prefer16) return ns_float16;
}
return (dsize_ == 3) ? ns_float64 :
(dsize_ == 2 ) ? ns_float32 : ns_float16;
}
@ -208,25 +210,29 @@ inline NanoString int_dtype(int dsize_) {
(dsize_ == 1) ? ns_int16 : ns_int8;
}
inline NanoString dtype_infer(NanoString x, NanoString y) {
inline NanoString dtype_infer(NanoString x, NanoString y, bool xscalar=false, bool yscalar=false) {
int dsize_ = std::max(x.dsize_(), y.dsize_());
if (xscalar) dsize_ = y.dsize_();
if (yscalar) dsize_ = x.dsize_();
bool is_float = x.is_float() || y.is_float();
if (is_float)
return float_dtype(dsize_);
return float_dtype(dsize_, xscalar||yscalar);
else {
return int_dtype(dsize_);
}
}
inline NanoString binary_dtype_infer(NanoString op, NanoString x, NanoString y) {
inline NanoString binary_dtype_infer(NanoString op, NanoString x, NanoString y, bool xscalar=false, bool yscalar=false) {
if (op.is_bool()) return ns_bool;
int dsize_ = std::max(x.dsize_(), y.dsize_());
if (xscalar) dsize_ = y.dsize_();
if (yscalar) dsize_ = x.dsize_();
bool is_float = !op.is_int() &&
(x.is_float() || y.is_float() || op.is_float());
if (is_float) {
if (op.is_white() && !(amp_reg & amp_keep_white))
return (dsize_ == 3) ? ns_float64 : ns_float32;
return float_dtype(dsize_);
return float_dtype(dsize_, xscalar||yscalar);
} else {
if (x.is_bool() && y.is_bool()) return ns_bool;
return int_dtype(dsize_);

View File

@ -45,6 +45,7 @@ struct NodeFlags {
_needed_by_backward=_n+3,
_out_hint=_n+4,
_th_require_grad=_n+5,
_is_scalar=_n+5,
// op related flags
// bit0: support cpu

View File

@ -53,6 +53,7 @@ ArrayOp::ArrayOp(ArrayArgs&& args) {
NanoVector shape = output->shape;
if (shape.size() == 1 && shape[0] == 1) {
output->flags.set(NodeFlags::_force_fuse);
output->flags.set(NodeFlags::_is_scalar);
set_type(OpType::element);
}
#ifdef HAS_CUDA

View File

@ -445,7 +445,7 @@ BinaryOp::BinaryOp(Var* x, Var* y, NanoString op) : x(x), y(y) {
set_type(OpType::element);
ns = op;
ASSERT(ns.is_binary());
z = create_output(x->shape, binary_dtype_infer(op, x->ns, y->ns));
z = create_output(x->shape, binary_dtype_infer(op, x->ns, y->ns, x->flags.get(NodeFlags::_is_scalar), y->flags.get(NodeFlags::_is_scalar)));
bool bin = ns.get(NanoString::_no_need_back_in);
bool bout = ns.get(NanoString::_no_need_back_out);
if (bin || bout) {

View File

@ -163,6 +163,7 @@ void BroadcastToOp::infer_shape() {
NanoVector zshape;
for (int i=0; i<zdim; i++) zshape.push_back(zz[i]);
z->set_shape(zshape);
z->flags.set(NodeFlags::_is_scalar, x->flags.get(NodeFlags::_is_scalar));
LOGvvv << "Broadcast x(" >> x >> ") shape" << yshapes << "-> z(" >> z >> ")";
}

View File

@ -23,8 +23,6 @@ static auto make_ternary = get_op_info("ternary")
static auto make_number = get_op_info("number")
.get_constructor<VarPtr, float, Var*>();
NanoString binary_dtype_infer(NanoString op, Var* dx, Var* dy);
unordered_set<string> reduce_ops = {
/**
Returns the maximum elements in the input.
@ -310,6 +308,8 @@ void ReduceOp::infer_shape() {
keepdims_mask |= 1;
}
y->set_shape(yshape);
if (yshape.size() == 1 && y->num == 1)
y->flags.set(NodeFlags::_is_scalar);
}
VarPtr ReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {

View File

@ -42,7 +42,7 @@ TernaryOp::TernaryOp(Var* cond, Var* x, Var* y) : cond(cond), x(x), y(y) {
if (x->dtype() == y->dtype()) {
z = create_output(nullptr, x->dtype());
} else {
z = create_output(nullptr, dtype_infer(x->ns, y->ns));
z = create_output(nullptr, dtype_infer(x->ns, y->ns, x->flags.get(NodeFlags::_is_scalar), y->flags.get(NodeFlags::_is_scalar)));
}
}

View File

@ -854,6 +854,7 @@ UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) {
} else
dtype = unary_dtype_infer(ns, x->ns);
y = create_output(nullptr, dtype);
y->flags.set(NodeFlags::_is_scalar, x->flags.get(NodeFlags::_is_scalar));
bool bin = ns.get(NanoString::_no_need_back_in);
bool bout = ns.get(NanoString::_no_need_back_out);
if (bin || bout) {

View File

@ -119,6 +119,7 @@ ArrayOp::ArrayOp(PyObject* obj) {
int64 size = output->size;
if (shape.size() == 1 && shape[0] == 1) {
output->flags.set(NodeFlags::_force_fuse);
output->flags.set(NodeFlags::_is_scalar);
set_type(OpType::element);
}
void* host_ptr = nullptr;

View File

@ -196,6 +196,25 @@ DEF_IS(MemInfo, const T&) from_py_object(PyObject* obj) {
return GET_RAW_PTR(T, obj);
}
// ZipFile
// Python <-> C++ conversion helpers for ZipFile. Only the declaration is
// needed here; the struct itself is defined elsewhere.
struct ZipFile;
// Python type object backing ZipFile instances (defined in another library).
EXTERN_LIB PyTypeObject PyjtZipFile;
// True when obj's Python type is exactly PyjtZipFile (no subclass check).
DEF_IS(ZipFile, bool) is_type(PyObject* obj) {
return Py_TYPE(obj) == &PyjtZipFile;
}
// Allocates a new Python object of type PyjtZipFile and copy-constructs `a`
// into its storage with placement new; the holder releases ownership to the
// caller on success.
DEF_IS(ZipFile, PyObject*) to_py_object(const T& a) {
PyObjHolder obj(_PyObject_New(&PyjtZipFile));
auto ptr = GET_RAW_PTR(T, obj.obj);
new (ptr) T(a);
return obj.release();
}
// Returns the C++ object stored inside obj without copying.
// NOTE(review): presumably the caller has already verified the type with
// is_type — confirm at call sites.
DEF_IS(ZipFile, const T&) from_py_object(PyObject* obj) {
return GET_RAW_PTR(T, obj);
}
// NanoString
struct NanoString;

View File

@ -251,12 +251,21 @@ struct VarHolder {
*/
// @pyjt(__get__data)
inline DataView data() {
sync(true);
sync(true, false);
#ifdef HAS_CUDA
migrate_to_cpu(var, exe.allocator);
#endif
return {this, var->mem_ptr, var->shape, var->dtype()};
}
/**
 * returns the address stored in var->mem_ptr as an unsigned 64-bit integer.
 * The Var is synced first; in CUDA builds it is also passed through
 * migrate_to_cpu before the pointer is read.
 * NOTE(review): no ownership is transferred; presumably the address is only
 * meaningful while this Var keeps its buffer — confirm.
 */
// @pyjt(__get__raw_ptr)
inline uint64 raw_ptr() {
sync(true, false);
#ifdef HAS_CUDA
migrate_to_cpu(var, exe.allocator);
#endif
return (uint64)var->mem_ptr;
}
/**
* returns the Python number if the Var contains only one element.

View File

@ -280,6 +280,13 @@ class TestFP16(unittest.TestCase):
a.half()
assert a.weight.dtype == "float16"
def test_scalar(self):
    """Check result dtypes when a float16 Var meets scalars and other Vars."""
    half = jt.float16([1,2,3])
    full = jt.float32([1,2,3])
    # A Python scalar or another float16 Var keeps float16.
    assert (half*1).dtype == "float16"
    assert (half*jt.float16([1,2,3])).dtype == "float16"
    # A full-size float32 Var promotes the result.
    assert (half*full).dtype == "float32"
    # A reduced (single-element) float32 value does not promote the result.
    assert (half*full.sum()).dtype == "float16"
    assert jt.int([0,1,0]).ternary(half, jt.float32(1)).dtype == "float16"
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")

View File

@ -9,21 +9,6 @@ from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO,
loaded_storages = {}
deserialized_objects = {}
def _is_zipfile(fn):
f = open(fn, "rb")
read_bytes = []
start = f.tell()
byte = f.read(1)
while byte != "":
read_bytes.append(byte)
if len(read_bytes) == 4:
break
byte = f.read(1)
f.seek(start)
local_header_magic_number = [b'P', b'K', b'\x03', b'\x04']
return read_bytes == local_header_magic_number
def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
if isinstance(bytes_str, bytes):
@ -31,8 +16,8 @@ def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
return bytes_str
def load_tensor(contents, dtype, numel, key, location):
name = os.path.join("archive", "data", str(key))
loaded_storages[key] = np.frombuffer(contents[name], dtype).copy()
name = os.path.join(prefix, "data", str(key))
loaded_storages[key] = contents.read_var(name, dtype)
def get_dtype_size(dtype):
dtype = dtype.__str__()
@ -208,22 +193,31 @@ def persistent_load_direct(saved_id):
raise RuntimeError("Unknown saved id type: %s" % saved_id[0])
def load_pytorch(fn_name):
global contents, deserialized_objects, loaded_storages
import jittor as jt
global contents, deserialized_objects, loaded_storages, prefix
loaded_storages = {}
deserialized_objects = {}
if not (fn_name.endswith(".pth") or fn_name.endswith(".pt") or fn_name.endswith(".bin")):
print("This function is designed to load pytorch pth format files.")
return None
else:
if _is_zipfile(fn_name):
contents = jt.ZipFile(fn_name)
if contents.valid():
loaded_storages = {}
deserialized_objects = {}
contents = extract_zip(fn_name)
data_file = io.BytesIO(contents['archive/data.pkl'])
pickle_load_args = {'encoding': 'utf-8'}
unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
unpickler.persistent_load = persistent_load
result = unpickler.load()
for name in contents.list():
if "data.pkl" in name:
prefix = name[:-8]
break
else:
raise RuntimeError(f"zipfile <{fn_name}> format error, data.pkl not found")
with jt.flag_scope(use_cuda=0):
data_file = contents.read_var(prefix+"data.pkl").data.tobytes()
data_file = io.BytesIO(data_file)
pickle_load_args = {'encoding': 'utf-8'}
unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
unpickler.persistent_load = persistent_load
result = unpickler.load()
else:
deserialized_objects = {}
f = open(fn_name, "rb")

View File

@ -0,0 +1,271 @@
import pickle
import os
import io
import shutil
from zipfile import ZipFile
import jittor as jt
import numpy as np
from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO, List
# Storages already read from a zip-format checkpoint, keyed by storage key
# (filled by load_tensor, consulted by persistent_load).
loaded_storages = {}
# Arrays allocated while unpickling a legacy (non-zip) checkpoint, keyed by
# root key or view key (filled by persistent_load_direct).
deserialized_objects = {}
def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
if isinstance(bytes_str, bytes):
return bytes_str.decode('ascii')
return bytes_str
def load_tensor(contents, dtype, numel, key, location):
    """Read one storage record from the opened archive and cache it.

    The member ``<prefix>data/<key>`` is read through ``contents.read_var``
    and stored in the module-level ``loaded_storages`` under ``key``.

    Args:
        contents: Archive object exposing ``read_var(name, dtype)``.
        dtype: Element type forwarded to ``read_var``.
        numel: Unused; kept for call compatibility.
        key: Storage key; its string form is the member's file name.
        location: Unused; kept for call compatibility.
    """
    # Zip member names always use "/" as separator. os.path.join would emit
    # "\\" on Windows and miss the entry, so join with posixpath instead.
    import posixpath
    name = posixpath.join(prefix, "data", str(key))
    loaded_storages[key] = contents.read_var(name, dtype)
def get_dtype_size(dtype):
    """Return the per-element size in bytes for ``dtype``.

    The lookup is done on the dtype's string form: the 32-bit float/int names
    give 4, the 64-bit names give 8, the 16-bit names give 2, and every other
    name falls back to 1.
    """
    dtype_name = dtype.__str__()
    widths = (
        (4, ("float32", "int32")),
        (8, ("float64", "int64")),
        (2, ("float16", "int16")),
    )
    for width, names in widths:
        if dtype_name in names:
            return width
    return 1
def persistent_load(saved_id):
    """Resolve a pickle persistent id of a zip-format checkpoint to its storage.

    ``saved_id`` must be a tuple ``('storage', storage_type, key, location,
    numel)``. The storage is read through ``load_tensor`` the first time its
    key is seen and served from ``loaded_storages`` afterwards.
    """
    global contents
    assert isinstance(saved_id, tuple)
    typename = _maybe_decode_ascii(saved_id[0])
    assert typename == 'storage', \
        f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
    storage_type, key, location, numel = saved_id[1:]
    dtype = storage_type.dtype
    if key in loaded_storages:
        return loaded_storages[key]
    # First sighting of this key: read the record, then serve it from the cache.
    load_tensor(contents, dtype, numel, key, _maybe_decode_ascii(location))
    return loaded_storages[key]
def _dtype_to_storage_type_map():
return {
np.float16: 'HalfStorage',
np.float32: 'FloatStorage',
np.int64: 'LongStorage',
np.int32: 'IntStorage',
np.int16: 'ShortStorage',
np.int8: 'CharStorage'
}
def _storage_type_to_dtype_map():
    """Return the inverse of ``_dtype_to_storage_type_map()``.

    Keys are storage class names and values are the matching numpy types.
    """
    inverse = {}
    for np_dtype, storage_name in _dtype_to_storage_type_map().items():
        inverse[storage_name] = np_dtype
    return inverse
def _get_dtype_from_pickle_storage_type(pickle_storage_type: str):
    """Return the numpy type registered for a pickled storage class name.

    Raises:
        KeyError: If ``pickle_storage_type`` is not a known storage name.
    """
    known = _storage_type_to_dtype_map()
    try:
        return known[pickle_storage_type]
    except KeyError:
        raise KeyError(
            f'pickle storage type "{pickle_storage_type}" is not recognized')
class StorageType():
    """Stand-in for a pickled storage class that only carries a dtype."""

    def __init__(self, name):
        """Resolve the storage class name ``name`` to its dtype.

        A ``KeyError`` from the lookup propagates for unknown names.
        """
        resolved = _get_dtype_from_pickle_storage_type(name)
        self.dtype = resolved

    def __str__(self):
        """Return a short description that includes the dtype."""
        return f'StorageType(dtype={self.dtype})'
def jittor_rebuild(storage, storage_offset, size, stride, requires_grad, backward_hooks):
    """Rebuild a tensor from its flat storage (stand-in for ``_rebuild_tensor_v2``).

    Args:
        storage: Flat, sliceable storage holding the tensor's elements.
        storage_offset: Index in ``storage`` of the tensor's first element.
        size: Tensor shape; an empty sequence denotes a 0-dim tensor.
        stride: Unused.
        requires_grad: Unused.
        backward_hooks: Unused.

    Returns:
        The result of ``jt.array`` over the tensor's slice of the storage,
        reshaped to ``size`` when ``size`` is non-empty. A 0-dim tensor is
        returned as its single-element slice.
    """
    # Several tensors may share one storage, so the slice has to start at
    # storage_offset; always reading from index 0 returned the wrong data.
    # NOTE(review): stride is still ignored, so the elements are assumed to be
    # laid out contiguously in row-major order — confirm for strided tensors.
    if len(size) == 0:
        return jt.array(storage[storage_offset:storage_offset + 1])
    record_size = int(np.prod(size))
    return jt.array(storage[storage_offset:storage_offset + record_size]).reshape(size)
def jittor_rebuild_var(data, requires_grad, backward_hooks):
    """Rebuild a parameter: wrap ``data`` with ``jt.array`` and set its grad flag.

    ``backward_hooks`` is accepted for signature compatibility and not used.
    """
    var = jt.array(data)
    var.requires_grad = requires_grad
    return var
class UnpicklerWrapper(pickle.Unpickler):  # type: ignore[name-defined]
    """Unpickler that swaps PyTorch storage classes and rebuild functions.

    Names not handled here are resolved by the stock ``pickle.Unpickler``
    lookup, so only load checkpoints from trusted sources.
    """

    def find_class(self, mod_name, name):
        """Resolve a pickled global, substituting the loader's own stand-ins."""
        if type(name) is str:
            if 'Storage' in name:
                try:
                    return StorageType(name)
                except KeyError:
                    # Unknown storage name: fall through to the checks below.
                    pass
            substitutes = (
                ('_rebuild_tensor_v2', "jittor_rebuild"),
                ('_rebuild_parameter', "jittor_rebuild_var"),
            )
            for marker, target in substitutes:
                if marker in name:
                    return super().find_class("jittor_utils.load_pytorch", target)
        return super().find_class(mod_name, name)
class ArrayWrapper:
    """Plain record pairing a storage with optional tensor metadata."""

    def __init__(self, storage, stride=None, size=None, requires_grad=None):
        """Store the storage together with its stride, size and grad flag."""
        self.storage = storage
        self.size = size
        self.stride = stride
        self.requires_grad = requires_grad

    def __str__(self):
        """Return the string form of the wrapped storage."""
        return self.storage.__str__()
def jittor_rebuild_direct(storage, storage_offset, size, stride, requires_grad, backward_hooks):
    """Describe a legacy-format tensor as an ``ArrayWrapper`` over its storage.

    Args:
        storage: Flat, sliceable storage backing the tensor.
        storage_offset: Index in ``storage`` of the tensor's first element.
        size: Tensor shape; an empty sequence denotes a 0-dim tensor.
        stride: Per-dimension element strides.
        requires_grad: Unused.
        backward_hooks: Unused.

    Returns:
        ArrayWrapper: Holds the slice of ``storage`` spanned by the tensor,
        together with ``stride`` and ``size``.
    """
    # The old body called storage.reshape(size) and dropped the result, which
    # did nothing useful (it could only raise), and it ignored storage_offset.
    # Instead, hand over exactly the span the tensor addresses: from the
    # offset up to the highest index reachable through size and stride.
    if any(dim == 0 for dim in size):
        extent = 0
    else:
        extent = 1 + sum((dim - 1) * step for dim, step in zip(size, stride))
    # NOTE(review): presumably storage is a numpy array here, so the slice is
    # a view and later in-place fills of the storage stay visible — confirm.
    span = storage[storage_offset:storage_offset + extent]
    return ArrayWrapper(span, stride=stride, size=size)
def jittor_rebuild_var_direct(data, requires_grad, backward_hooks):
    """Rebuild a legacy-format parameter as an ``ArrayWrapper`` with a grad flag.

    Args:
        data: The parameter's tensor, normally the ``ArrayWrapper`` produced
            when the inner tensor was rebuilt.
        requires_grad: Flag recorded on the returned wrapper.
        backward_hooks: Unused.

    Returns:
        ArrayWrapper: Carries the tensor's storage, stride and size plus
        ``requires_grad``.
    """
    # The previous body referenced an undefined name ``storage`` and raised
    # NameError on every call; the tensor arrives in ``data``.
    if isinstance(data, ArrayWrapper):
        # Keep the inner tensor's layout instead of nesting wrappers.
        return ArrayWrapper(data.storage, stride=data.stride, size=data.size,
                            requires_grad=requires_grad)
    return ArrayWrapper(data, requires_grad=requires_grad)
class DirectUnpicklerWrapper(pickle.Unpickler):  # type: ignore[name-defined]
    """Unpickler for legacy checkpoints that swaps storages and rebuild functions.

    Names not handled here are resolved by the stock ``pickle.Unpickler``
    lookup, so only load checkpoints from trusted sources.
    """

    def find_class(self, mod_name, name):
        """Resolve a pickled global, substituting the ``*_direct`` stand-ins."""
        if type(name) is str:
            if 'Storage' in name:
                try:
                    return StorageType(name)
                except KeyError:
                    # Unknown storage name: fall through to the checks below.
                    pass
            substitutes = (
                ('_rebuild_tensor_v2', "jittor_rebuild_direct"),
                ('_rebuild_parameter', "jittor_rebuild_var_direct"),
            )
            for marker, target in substitutes:
                if marker in name:
                    return super().find_class("jittor_utils.load_pytorch", target)
        return super().find_class(mod_name, name)
def _check_seekable(f) -> bool:
def raise_err_msg(patterns, e):
for p in patterns:
if p in str(e):
msg = (str(e) + ". You can only load from a file that is seekable."
+ " Please pre-load the data into a buffer like io.BytesIO and"
+ " try to load from it instead.")
raise type(e)(msg)
raise e
try:
f.seek(f.tell())
return True
except (io.UnsupportedOperation, AttributeError) as e:
raise_err_msg(["seek", "tell"], e)
return False
def extract_zip(input_zip):
    """Read every member of a zip archive into memory.

    Args:
        input_zip: Path or binary file object accepted by ``zipfile.ZipFile``.

    Returns:
        dict: Member name mapped to the member's full content as ``bytes``.
    """
    # Open under ``with`` so the archive handle is released once all members
    # are read; the handle used to stay open until garbage collection.
    with ZipFile(input_zip) as archive:
        return {name: archive.read(name) for name in archive.namelist()}
def _is_compressed_file(f):
compress_modules = ['gzip']
try:
return f.__module__ in compress_modules
except AttributeError:
return False
def _should_read_directly(f):
    """Return True when ``f`` is an uncompressed file with a usable descriptor.

    Compressed files, and objects whose ``fileno()`` is missing or
    unsupported, give False.
    """
    if _is_compressed_file(f):
        return False
    try:
        descriptor = f.fileno()
    except (io.UnsupportedOperation, AttributeError):
        return False
    return descriptor >= 0
def persistent_load_direct(saved_id):
    """Resolve a persistent id found in a legacy (non-zip) checkpoint.

    ``('module', obj, ...)`` ids return ``obj``. ``('storage', data_type,
    root_key, location, size, view_metadata)`` ids return a zero-filled array
    cached in ``deserialized_objects``, or a slice of it when view metadata
    is present.

    Raises:
        RuntimeError: For any other id type.
    """
    global deserialized_objects
    assert isinstance(saved_id, tuple)
    typename = _maybe_decode_ascii(saved_id[0])
    data = saved_id[1:]

    if typename == 'module':
        # Ignore containers that don't have any sources saved
        return data[0]
    if typename != 'storage':
        raise RuntimeError("Unknown saved id type: %s" % saved_id[0])

    data_type, root_key, location, size, view_metadata = data
    location = _maybe_decode_ascii(location)
    if root_key not in deserialized_objects:
        # Allocate the backing array now; the loader fills it in later.
        deserialized_objects[root_key] = np.zeros(size, dtype=data_type)
    storage = deserialized_objects[root_key]
    if view_metadata is None:
        return storage

    view_key, offset, view_size = view_metadata
    if view_key not in deserialized_objects:
        deserialized_objects[view_key] = storage[offset:offset + view_size]
    return deserialized_objects[view_key]
def load_pytorch(fn_name):
    """Load a PyTorch checkpoint file and return its unpickled content.

    Two on-disk layouts are handled: the zip-based format (read through
    ``jt.ZipFile``) and the legacy stream format (magic number, protocol
    version, system info, pickled object, then raw storage data).

    Args:
        fn_name: Path of the checkpoint; must end with ``.pth``, ``.pt`` or
            ``.bin``.

    Returns:
        The unpickled object, or ``None`` (after printing a notice) when the
        file extension is not supported. For legacy files the result's values
        are converted to Jittor arrays.

    Raises:
        RuntimeError: If a zip archive has no ``data.pkl`` member, or a legacy
            file has a wrong magic number or protocol version.

    Note:
        The content is unpickled, which can run arbitrary code. Only load
        checkpoints from trusted sources.
    """
    global contents, deserialized_objects, loaded_storages, prefix
    loaded_storages = {}
    deserialized_objects = {}
    if not fn_name.endswith((".pth", ".pt", ".bin")):
        print("This function is designed to load pytorch pth format files.")
        return None

    import jittor as jt

    pickle_load_args = {'encoding': 'utf-8'}
    contents = jt.ZipFile(fn_name)
    if contents.valid():
        pkl_name = "data.pkl"
        for name in contents.list():
            # The prefix is obtained by cutting the suffix off, so the member
            # must actually end with it (a substring match is not enough).
            if name.endswith(pkl_name):
                prefix = name[:-len(pkl_name)]
                break
        else:
            raise RuntimeError(f"zipfile <{fn_name}> format error, data.pkl not found")
        # NOTE(review): the original indentation of this region was lost; it is
        # assumed that only the read of data.pkl ran with CUDA disabled —
        # confirm against the upstream file.
        with jt.flag_scope(use_cuda=0):
            data_file = contents.read_var(prefix + pkl_name).data.tobytes()
        unpickler = UnpicklerWrapper(io.BytesIO(data_file), **pickle_load_args)
        unpickler.persistent_load = persistent_load
        return unpickler.load()

    MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
    PROTOCOL_VERSION = 1001
    # ``with`` closes the file on every exit path, including the errors below.
    with open(fn_name, "rb") as f:
        magic_number = pickle.load(f, **pickle_load_args)
        if magic_number != MAGIC_NUMBER:
            raise RuntimeError("Invalid magic number; corrupt file?")
        protocol_version = pickle.load(f, **pickle_load_args)
        if PROTOCOL_VERSION != protocol_version:
            raise RuntimeError("Invalid protocol version.")
        # The system-info record is not used, but must be consumed to advance.
        pickle.load(f, **pickle_load_args)
        unpickler = DirectUnpicklerWrapper(f, **pickle_load_args)
        unpickler.persistent_load = persistent_load_direct
        result = unpickler.load()
        deserialized_storage_keys = pickle.load(f, **pickle_load_args)
        # Each storage's raw data is preceded by 8 bytes that are skipped.
        f.read(8)
        for key in deserialized_storage_keys:
            assert key in deserialized_objects
            target = deserialized_objects[key]
            byte_data = f.read(target.size * get_dtype_size(target.dtype))
            # Fill in place so arrays already handed out while unpickling
            # see the data.
            target[:] = np.frombuffer(byte_data, target.dtype)
            f.read(8)

    for key, params in result.items():
        var = jt.array(params.storage)
        shape = params.size
        if shape is not None and len(shape) > 0:
            if len(params.stride) > 1:
                # Gather the elements through the saved strides.
                evals = "+".join(
                    f"@e0({idx}) * i{idx}" for idx in range(len(params.stride)))
                var = var.reindex(shape, [evals], extras=[jt.array(params.stride)])
            else:
                var = var.reshape(shape)
        if params.requires_grad is not None:
            var.requires_grad = params.requires_grad
        result[key] = var
    return result
# Manual smoke test: load the checkpoint "van_base.pth" from the current
# working directory and print each entry's key and shape.
if __name__ == "__main__":
result = load_pytorch("van_base.pth")
for key, val in result.items():
print(key, val.shape)