mirror of https://github.com/Jittor/Jittor
chatglm optimize v2
This commit is contained in:
parent
1e00883ed1
commit
faa5386d82
@@ -1335,7 +1335,7 @@ class Module:
         state = new_state
         self.load_state_dict(state)
 
-    def cuda(self):
+    def cuda(self, device=None):
         flags.use_cuda = 1
         return self
 
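The extra `device` argument appears to exist only so PyTorch-style calls such as `model.cuda(0)` are accepted; judging from the hunk above the index is not used and CUDA is simply enabled globally. A minimal sketch of the new call pattern (module choice is arbitrary):

import jittor as jt
from jittor import nn

model = nn.Linear(4, 2)
model.cuda()     # same behaviour as before
model.cuda(0)    # now also accepted; the device index is ignored and
                 # jt.flags.use_cuda is simply set to 1
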
@@ -1585,10 +1585,10 @@ Arguments of hook are defined as::
             if param.shape == v.shape:
                 LOG.v(f'load parameter {key} success ...')
                 v.update(param)
+                v.sync(False, False)
             else:
                 n_failed += 1
                 LOG.e(f'load parameter {key} failed: expect the shape of {key} to be {v.shape}, but got {param.shape}')
-        jt.sync_all()
         if n_failed:
             LOG.w(f"load total {len(params)} params, {n_failed} failed")
 
@@ -1680,6 +1680,16 @@ Arguments of hook are defined as::
         if not hasattr(self, "is_train"):
             self.is_train = True
         return self.is_train
 
+    @property
+    def training(self):
+        if not hasattr(self, "is_train"):
+            self.is_train = True
+        return self.is_train
+
+    @training.setter
+    def training(self, value):
+        self.is_train = value
+
     def mpi_param_broadcast(self, root=0):
         if not in_mpi: return
 
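The new `training` property mirrors the existing `is_train` flag, so modules can be queried and toggled with the PyTorch-style attribute as well as with `train()`/`eval()`. A small usage sketch:

import jittor as jt
from jittor import nn

m = nn.Dropout(0.5)
m.eval()
print(m.training)   # False, reads the underlying is_train flag
m.training = True   # the setter writes is_train directly
print(m.is_train)   # True
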
@@ -2125,7 +2125,7 @@ def multinomial(weights: jt.Var, num_samples: int, replacement: bool=False) -> jt.Var:
         # A-Res algorithm
         # Pavlos S. Efraimidis and Paul G. Spirakis, 2006, Weighted random sampling with a reservoir
         assert num_samples <= weights.shape[-1], "num_samples larger than the input"
-        rand = jt.rand(weights.shape) ** (1/weights)
+        rand = jt.rand(weights.shape) ** ((1/weights).safe_clip())
         _, indices = jt.topk(rand.safe_clip(), num_samples)
         return indices
 
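For reference, the A-Res scheme cited in the comment draws a key u_i ** (1/w_i) for every item (u_i uniform in (0,1)) and keeps the items with the largest keys; the extra clipping above presumably guards 1/w against overflow when a weight is tiny, especially in half precision. A plain NumPy sketch of the same idea (the function name here is ours, not part of Jittor):

import numpy as np

def a_res_sample(weights, num_samples, rng=np.random.default_rng()):
    # key_i = u_i ** (1 / w_i); taking the top-k keys samples
    # k items without replacement, proportionally to the weights
    u = rng.random(len(weights))
    keys = u ** (1.0 / np.asarray(weights, dtype=float))
    return np.argsort(-keys)[:num_samples]

print(a_res_sample([0.1, 0.7, 0.2], 2))
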
File diff suppressed because it is too large
File diff suppressed because it is too large

@@ -17,7 +17,11 @@ namespace jittor {
 __global__ void _check_nan_float16(__half* __restrict__ ptr, int64 num) {
     int64 i = threadIdx.x + blockIdx.x * (int64)blockDim.x;
     if (i<num) {
+        #if JT_CHECK_NAN == 2
+        if (isnan(__half2float(ptr[i])))
+        #else
         if (isnan(__half2float(ptr[i])) || __hisinf(ptr[i]))
+        #endif
             __trap();
     }
 }

@@ -25,7 +29,11 @@ __global__ void _check_nan_float16(__half* __restrict__ ptr, int64 num) {
 __global__ void _check_nan_float32(float32* __restrict__ ptr, int64 num) {
     int64 i = threadIdx.x + blockIdx.x * (int64)blockDim.x;
     if (i<num) {
+        #if JT_CHECK_NAN == 2
+        if (::isnan(ptr[i]))
+        #else
         if (::isnan(ptr[i]) || ::isinf(ptr[i]))
+        #endif
             __trap();
     }
 }

@@ -34,7 +42,11 @@ __global__ void _check_nan_float32(float32* __restrict__ ptr, int64 num) {
 __global__ void _check_nan_float64(float64* __restrict__ ptr, int64 num) {
     int64 i = threadIdx.x + blockIdx.x * (int64)blockDim.x;
     if (i<num) {
+        #if JT_CHECK_NAN == 2
+        if (::isnan(ptr[i]))
+        #else
         if (::isnan(ptr[i]) || ::isinf(ptr[i]))
+        #endif
             __trap();
     }
 }

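The new JT_CHECK_NAN == 2 mode traps only on NaN, while the default mode keeps trapping on NaN or Inf. The two predicates, written out on the host side purely for illustration:

import numpy as np

def bad_value(x, nan_only=False):
    # JT_CHECK_NAN == 2  -> flag NaN only
    # default            -> flag NaN or Inf
    return bool(np.isnan(x)) if nan_only else bool(np.isnan(x) or np.isinf(x))

print(bad_value(float("inf")))         # True in the default mode
print(bad_value(float("inf"), True))   # False when only NaN is checked
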
@@ -195,9 +195,11 @@ constexpr int amp_keep_reduce = 4;
 constexpr int amp_keep_white = 8;
 constexpr int amp_array_prefer = 16;
 
-inline NanoString float_dtype(int dsize_) {
-    if (amp_reg & amp_prefer32) return ns_float32;
-    if (amp_reg & amp_prefer16) return ns_float16;
+inline NanoString float_dtype(int dsize_, bool has_scalar=false) {
+    if (!has_scalar) {
+        if (amp_reg & amp_prefer32) return ns_float32;
+        if (amp_reg & amp_prefer16) return ns_float16;
+    }
     return (dsize_ == 3) ? ns_float64 :
         (dsize_ == 2) ? ns_float32 : ns_float16;
 }

@@ -208,25 +210,29 @@ inline NanoString int_dtype(int dsize_) {
         (dsize_ == 1) ? ns_int16 : ns_int8;
 }
 
-inline NanoString dtype_infer(NanoString x, NanoString y) {
+inline NanoString dtype_infer(NanoString x, NanoString y, bool xscalar=false, bool yscalar=false) {
     int dsize_ = std::max(x.dsize_(), y.dsize_());
+    if (xscalar) dsize_ = y.dsize_();
+    if (yscalar) dsize_ = x.dsize_();
     bool is_float = x.is_float() || y.is_float();
     if (is_float)
-        return float_dtype(dsize_);
+        return float_dtype(dsize_, xscalar||yscalar);
     else {
         return int_dtype(dsize_);
     }
 }
 
-inline NanoString binary_dtype_infer(NanoString op, NanoString x, NanoString y) {
+inline NanoString binary_dtype_infer(NanoString op, NanoString x, NanoString y, bool xscalar=false, bool yscalar=false) {
     if (op.is_bool()) return ns_bool;
     int dsize_ = std::max(x.dsize_(), y.dsize_());
+    if (xscalar) dsize_ = y.dsize_();
+    if (yscalar) dsize_ = x.dsize_();
     bool is_float = !op.is_int() &&
         (x.is_float() || y.is_float() || op.is_float());
     if (is_float) {
         if (op.is_white() && !(amp_reg & amp_keep_white))
             return (dsize_ == 3) ? ns_float64 : ns_float32;
-        return float_dtype(dsize_);
+        return float_dtype(dsize_, xscalar||yscalar);
     } else {
         if (x.is_bool() && y.is_bool()) return ns_bool;
         return int_dtype(dsize_);

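In effect, when one operand is known to be a scalar its byte width no longer participates in type promotion, and the AMP preference flags are skipped as well, so a float16 tensor combined with a Python scalar or a one-element Var keeps its dtype instead of being widened. A rough Python model of the new rule, with made-up helper names that only mirror the C++ above:

def dtype_width(name):
    # dsize code: 1 -> 2 bytes, 2 -> 4 bytes, 3 -> 8 bytes (as in NanoString.dsize_)
    return {"float16": 1, "float32": 2, "float64": 3}[name]

def float_dtype(dsize, has_scalar=False):
    # amp preferences are omitted here; in the C++ they are only
    # consulted when neither operand is a scalar (has_scalar == False)
    return {1: "float16", 2: "float32", 3: "float64"}[dsize]

def binary_float_infer(x, y, xscalar=False, yscalar=False):
    dsize = max(dtype_width(x), dtype_width(y))
    if xscalar: dsize = dtype_width(y)   # a scalar operand does not widen the result
    if yscalar: dsize = dtype_width(x)
    return float_dtype(dsize, xscalar or yscalar)

print(binary_float_infer("float16", "float32"))                 # float32
print(binary_float_infer("float16", "float32", yscalar=True))   # float16
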
@@ -45,6 +45,7 @@ struct NodeFlags {
         _needed_by_backward=_n+3,
         _out_hint=_n+4,
         _th_require_grad=_n+5,
+        _is_scalar=_n+5,
 
         // op related flags
         // bit0: support cpu

@@ -53,6 +53,7 @@ ArrayOp::ArrayOp(ArrayArgs&& args) {
     NanoVector shape = output->shape;
     if (shape.size() == 1 && shape[0] == 1) {
         output->flags.set(NodeFlags::_force_fuse);
+        output->flags.set(NodeFlags::_is_scalar);
         set_type(OpType::element);
     }
 #ifdef HAS_CUDA

@@ -445,7 +445,7 @@ BinaryOp::BinaryOp(Var* x, Var* y, NanoString op) : x(x), y(y) {
     set_type(OpType::element);
     ns = op;
     ASSERT(ns.is_binary());
-    z = create_output(x->shape, binary_dtype_infer(op, x->ns, y->ns));
+    z = create_output(x->shape, binary_dtype_infer(op, x->ns, y->ns, x->flags.get(NodeFlags::_is_scalar), y->flags.get(NodeFlags::_is_scalar)));
     bool bin = ns.get(NanoString::_no_need_back_in);
     bool bout = ns.get(NanoString::_no_need_back_out);
     if (bin || bout) {

@@ -163,6 +163,7 @@ void BroadcastToOp::infer_shape() {
     NanoVector zshape;
     for (int i=0; i<zdim; i++) zshape.push_back(zz[i]);
     z->set_shape(zshape);
+    z->flags.set(NodeFlags::_is_scalar, x->flags.get(NodeFlags::_is_scalar));
     LOGvvv << "Broadcast x(" >> x >> ") shape" << yshapes << "-> z(" >> z >> ")";
 }
 
@@ -23,8 +23,6 @@ static auto make_ternary = get_op_info("ternary")
 static auto make_number = get_op_info("number")
     .get_constructor<VarPtr, float, Var*>();
 
-NanoString binary_dtype_infer(NanoString op, Var* dx, Var* dy);
-
 unordered_set<string> reduce_ops = {
     /**
     Returns the maximum elements in the input.

@@ -310,6 +308,8 @@ void ReduceOp::infer_shape() {
             keepdims_mask |= 1;
     }
     y->set_shape(yshape);
+    if (yshape.size() == 1 && y->num == 1)
+        y->flags.set(NodeFlags::_is_scalar);
 }
 
 VarPtr ReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {

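Together with the BinaryOp change above, this is what keeps mixed-precision code from being silently widened by reductions: a reduction that produces a single element is flagged as a scalar, so its dtype no longer dominates promotion. The new test in test_fp16.py below exercises exactly this; a minimal interactive version:

import jittor as jt

a = jt.float16([1, 2, 3])
s = jt.float32([1, 2, 3]).sum()            # one-element result, flagged as scalar
print((a * s).dtype)                       # float16, per the new inference rule
print((a * jt.float32([1, 2, 3])).dtype)   # float32, no scalar involved
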
@@ -42,7 +42,7 @@ TernaryOp::TernaryOp(Var* cond, Var* x, Var* y) : cond(cond), x(x), y(y) {
     if (x->dtype() == y->dtype()) {
         z = create_output(nullptr, x->dtype());
     } else {
-        z = create_output(nullptr, dtype_infer(x->ns, y->ns));
+        z = create_output(nullptr, dtype_infer(x->ns, y->ns, x->flags.get(NodeFlags::_is_scalar), y->flags.get(NodeFlags::_is_scalar)));
     }
 }
 
@@ -854,6 +854,7 @@ UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) {
     } else
         dtype = unary_dtype_infer(ns, x->ns);
     y = create_output(nullptr, dtype);
+    y->flags.set(NodeFlags::_is_scalar, x->flags.get(NodeFlags::_is_scalar));
     bool bin = ns.get(NanoString::_no_need_back_in);
     bool bout = ns.get(NanoString::_no_need_back_out);
     if (bin || bout) {

@@ -119,6 +119,7 @@ ArrayOp::ArrayOp(PyObject* obj) {
     int64 size = output->size;
     if (shape.size() == 1 && shape[0] == 1) {
         output->flags.set(NodeFlags::_force_fuse);
+        output->flags.set(NodeFlags::_is_scalar);
         set_type(OpType::element);
     }
     void* host_ptr = nullptr;

@@ -196,6 +196,25 @@ DEF_IS(MemInfo, const T&) from_py_object(PyObject* obj) {
     return GET_RAW_PTR(T, obj);
 }
 
+// ZipFile
+struct ZipFile;
+EXTERN_LIB PyTypeObject PyjtZipFile;
+DEF_IS(ZipFile, bool) is_type(PyObject* obj) {
+    return Py_TYPE(obj) == &PyjtZipFile;
+}
+
+
+DEF_IS(ZipFile, PyObject*) to_py_object(const T& a) {
+    PyObjHolder obj(_PyObject_New(&PyjtZipFile));
+    auto ptr = GET_RAW_PTR(T, obj.obj);
+    new (ptr) T(a);
+    return obj.release();
+}
+
+DEF_IS(ZipFile, const T&) from_py_object(PyObject* obj) {
+    return GET_RAW_PTR(T, obj);
+}
+
+
 // NanoString
 struct NanoString;

@@ -251,12 +251,21 @@ struct VarHolder {
      */
     // @pyjt(__get__data)
     inline DataView data() {
-        sync(true);
+        sync(true, false);
         #ifdef HAS_CUDA
         migrate_to_cpu(var, exe.allocator);
         #endif
         return {this, var->mem_ptr, var->shape, var->dtype()};
     }
 
+    // @pyjt(__get__raw_ptr)
+    inline uint64 raw_ptr() {
+        sync(true, false);
+        #ifdef HAS_CUDA
+        migrate_to_cpu(var, exe.allocator);
+        #endif
+        return (uint64)var->mem_ptr;
+    }
+
     /**
      * returns the Python number if the Var contains only one element.

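The @pyjt(__get__raw_ptr) annotation suggests this is exposed to Python as a read-only raw_ptr attribute that syncs the Var, migrates it to host memory when CUDA is active, and returns the address of the underlying buffer. A hedged usage sketch; the attribute spelling and the float32 default dtype are assumptions, not confirmed by this diff:

import ctypes
import jittor as jt

x = jt.array([1.0, 2.0, 3.0])                    # float32 by default (assumed)
addr = x.raw_ptr                                 # assumed property generated from __get__raw_ptr
buf = (ctypes.c_float * 3).from_address(addr)
print(list(buf))                                 # reads the buffer in place
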
@@ -280,6 +280,13 @@ class TestFP16(unittest.TestCase):
         a.half()
         assert a.weight.dtype == "float16"
 
+    def test_scalar(self):
+        a = jt.float16([1,2,3])
+        assert (a*1).dtype == "float16"
+        assert (a*jt.float16([1,2,3])).dtype == "float16"
+        assert (a*jt.float32([1,2,3])).dtype == "float32"
+        assert (a*jt.float32([1,2,3]).sum()).dtype == "float16"
+        assert jt.int([0,1,0]).ternary(a, jt.float32(1)).dtype == "float16"
 
 
 @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")

@@ -9,21 +9,6 @@ from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO, List
 
 loaded_storages = {}
 deserialized_objects = {}
-def _is_zipfile(fn):
-    f = open(fn, "rb")
-    read_bytes = []
-    start = f.tell()
-
-    byte = f.read(1)
-    while byte != "":
-        read_bytes.append(byte)
-        if len(read_bytes) == 4:
-            break
-        byte = f.read(1)
-    f.seek(start)
-
-    local_header_magic_number = [b'P', b'K', b'\x03', b'\x04']
-    return read_bytes == local_header_magic_number
-
 def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
     if isinstance(bytes_str, bytes):

@@ -31,8 +16,8 @@ def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
     return bytes_str
 
 def load_tensor(contents, dtype, numel, key, location):
-    name = os.path.join("archive", "data", str(key))
-    loaded_storages[key] = np.frombuffer(contents[name], dtype).copy()
+    name = os.path.join(prefix, "data", str(key))
+    loaded_storages[key] = contents.read_var(name, dtype)
 
 def get_dtype_size(dtype):
     dtype = dtype.__str__()

@@ -208,22 +193,31 @@ def persistent_load_direct(saved_id):
         raise RuntimeError("Unknown saved id type: %s" % saved_id[0])
 
 def load_pytorch(fn_name):
-    global contents, deserialized_objects, loaded_storages
+    import jittor as jt
+    global contents, deserialized_objects, loaded_storages, prefix
+    loaded_storages = {}
+    deserialized_objects = {}
     if not (fn_name.endswith(".pth") or fn_name.endswith(".pt") or fn_name.endswith(".bin")):
         print("This function is designed to load pytorch pth format files.")
         return None
     else:
-        if _is_zipfile(fn_name):
+        contents = jt.ZipFile(fn_name)
+        if contents.valid():
             loaded_storages = {}
             deserialized_objects = {}
-            contents = extract_zip(fn_name)
-            data_file = io.BytesIO(contents['archive/data.pkl'])
-            pickle_load_args = {'encoding': 'utf-8'}
-            unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
-            unpickler.persistent_load = persistent_load
-            result = unpickler.load()
+            for name in contents.list():
+                if "data.pkl" in name:
+                    prefix = name[:-8]
+                    break
+            else:
+                raise RuntimeError(f"zipfile <{fn_name}> format error, data.pkl not found")
+            with jt.flag_scope(use_cuda=0):
+                data_file = contents.read_var(prefix+"data.pkl").data.tobytes()
+                data_file = io.BytesIO(data_file)
+                pickle_load_args = {'encoding': 'utf-8'}
+                unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
+                unpickler.persistent_load = persistent_load
+                result = unpickler.load()
         else:
             deserialized_objects = {}
             f = open(fn_name, "rb")

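From the way it is used here, the new C++-backed jt.ZipFile wrapper opens a checkpoint, reports whether it is a valid zip via valid(), lists entry names via list(), and reads an entry straight into a Var via read_var(name[, dtype]), so the archive no longer has to be fully extracted into Python bytes first. A hedged sketch of that usage, mirroring the diff (the file name is a placeholder):

import jittor as jt

zf = jt.ZipFile("checkpoint.pth")        # placeholder path
if zf.valid():
    # PyTorch zip archives store the pickle entry as <prefix>/data.pkl
    prefix = next(n[:-8] for n in zf.list() if "data.pkl" in n)
    pkl_bytes = zf.read_var(prefix + "data.pkl").data.tobytes()
    print(prefix, len(pkl_bytes))
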
@@ -0,0 +1,271 @@
import pickle
import os
import io
import shutil
from zipfile import ZipFile
import jittor as jt
import numpy as np
from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO, List

loaded_storages = {}
deserialized_objects = {}

def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
    if isinstance(bytes_str, bytes):
        return bytes_str.decode('ascii')
    return bytes_str

def load_tensor(contents, dtype, numel, key, location):
    name = os.path.join(prefix, "data", str(key))
    loaded_storages[key] = contents.read_var(name, dtype)

def get_dtype_size(dtype):
    dtype = dtype.__str__()
    if dtype == "float32" or dtype == "int32":
        return 4
    if dtype == "float64" or dtype == "int64":
        return 8
    if dtype == "float16" or dtype == "int16":
        return 2
    return 1

def persistent_load(saved_id):
    global contents
    assert isinstance(saved_id, tuple)
    typename = _maybe_decode_ascii(saved_id[0])
    data = saved_id[1:]
    assert typename == 'storage', \
        f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
    storage_type, key, location, numel = data
    dtype = storage_type.dtype
    if key not in loaded_storages:
        nbytes = numel
        load_tensor(contents, dtype, nbytes, key, _maybe_decode_ascii(location))
    return loaded_storages[key]

def _dtype_to_storage_type_map():
    return {
        np.float16: 'HalfStorage',
        np.float32: 'FloatStorage',
        np.int64: 'LongStorage',
        np.int32: 'IntStorage',
        np.int16: 'ShortStorage',
        np.int8: 'CharStorage'
    }

def _storage_type_to_dtype_map():
    dtype_map = {
        val: key for key, val in _dtype_to_storage_type_map().items()}
    return dtype_map

def _get_dtype_from_pickle_storage_type(pickle_storage_type: str):
    try:
        return _storage_type_to_dtype_map()[pickle_storage_type]
    except KeyError:
        raise KeyError(
            f'pickle storage type "{pickle_storage_type}" is not recognized')

class StorageType():
    def __init__(self, name):
        self.dtype = _get_dtype_from_pickle_storage_type(name)

    def __str__(self):
        return f'StorageType(dtype={self.dtype})'

def jittor_rebuild(storage, storage_offset, size, stride, requires_grad, backward_hooks):
    if len(size) == 0:
        return jt.array(storage)
    record_size = np.prod(size)
    return jt.array(storage[:record_size]).reshape(size)

def jittor_rebuild_var(data, requires_grad, backward_hooks):
    v = jt.array(data)
    v.requires_grad = requires_grad
    return v

class UnpicklerWrapper(pickle.Unpickler):  # type: ignore[name-defined]
    def find_class(self, mod_name, name):
        if type(name) is str and 'Storage' in name:
            try:
                return StorageType(name)
            except KeyError:
                pass
        if type(name) is str and '_rebuild_tensor_v2' in name:
            return super().find_class("jittor_utils.load_pytorch", "jittor_rebuild")
        if type(name) is str and '_rebuild_parameter' in name:
            return super().find_class("jittor_utils.load_pytorch", "jittor_rebuild_var")

        return super().find_class(mod_name, name)

class ArrayWrapper:
    def __init__(self, storage, stride=None, size=None, requires_grad=None):
        self.requires_grad = requires_grad
        self.size = size
        self.storage = storage
        self.stride = stride

    def __str__(self):
        return self.storage.__str__()

def jittor_rebuild_direct(storage, storage_offset, size, stride, requires_grad, backward_hooks):
    if len(size) == 0:
        return ArrayWrapper(storage, stride=stride, size=size)
    storage.reshape(size)
    return ArrayWrapper(storage, stride=stride, size=size)

def jittor_rebuild_var_direct(data, requires_grad, backward_hooks):
    v = ArrayWrapper(data, requires_grad=requires_grad)
    return v

class DirectUnpicklerWrapper(pickle.Unpickler):  # type: ignore[name-defined]
    def find_class(self, mod_name, name):
        if type(name) is str and 'Storage' in name:
            try:
                return StorageType(name)
            except KeyError:
                pass
        if type(name) is str and '_rebuild_tensor_v2' in name:
            return super().find_class("jittor_utils.load_pytorch", "jittor_rebuild_direct")
        if type(name) is str and '_rebuild_parameter' in name:
            return super().find_class("jittor_utils.load_pytorch", "jittor_rebuild_var_direct")
        return super().find_class(mod_name, name)

def _check_seekable(f) -> bool:
    def raise_err_msg(patterns, e):
        for p in patterns:
            if p in str(e):
                msg = (str(e) + ". You can only load from a file that is seekable."
                       + " Please pre-load the data into a buffer like io.BytesIO and"
                       + " try to load from it instead.")
                raise type(e)(msg)
        raise e

    try:
        f.seek(f.tell())
        return True
    except (io.UnsupportedOperation, AttributeError) as e:
        raise_err_msg(["seek", "tell"], e)
    return False

def extract_zip(input_zip):
    input_zip = ZipFile(input_zip)
    return {name: input_zip.read(name) for name in input_zip.namelist()}

def _is_compressed_file(f):
    compress_modules = ['gzip']
    try:
        return f.__module__ in compress_modules
    except AttributeError:
        return False

def _should_read_directly(f):
    if _is_compressed_file(f):
        return False
    try:
        return f.fileno() >= 0
    except io.UnsupportedOperation:
        return False
    except AttributeError:
        return False

def persistent_load_direct(saved_id):
    global deserialized_objects
    assert isinstance(saved_id, tuple)
    typename = _maybe_decode_ascii(saved_id[0])
    data = saved_id[1:]
    if typename == 'module':
        # Ignore containers that don't have any sources saved
        return data[0]
    elif typename == 'storage':
        data_type, root_key, location, size, view_metadata = data
        location = _maybe_decode_ascii(location)
        if root_key not in deserialized_objects:
            deserialized_objects[root_key] = np.zeros(size, dtype=data_type)
        storage = deserialized_objects[root_key]
        if view_metadata is not None:
            view_key, offset, view_size = view_metadata
            if view_key not in deserialized_objects:
                deserialized_objects[view_key] = storage[offset:offset + view_size]
            return deserialized_objects[view_key]
        else:
            return storage
    else:
        raise RuntimeError("Unknown saved id type: %s" % saved_id[0])

def load_pytorch(fn_name):
    import jittor as jt
    global contents, deserialized_objects, loaded_storages, prefix
    loaded_storages = {}
    deserialized_objects = {}
    if not (fn_name.endswith(".pth") or fn_name.endswith(".pt") or fn_name.endswith(".bin")):
        print("This function is designed to load pytorch pth format files.")
        return None
    else:
        contents = jt.ZipFile(fn_name)
        if contents.valid():
            loaded_storages = {}
            deserialized_objects = {}
            for name in contents.list():
                if "data.pkl" in name:
                    prefix = name[:-8]
                    break
            else:
                raise RuntimeError(f"zipfile <{fn_name}> format error, data.pkl not found")
            with jt.flag_scope(use_cuda=0):
                data_file = contents.read_var(prefix+"data.pkl").data.tobytes()
                data_file = io.BytesIO(data_file)
                pickle_load_args = {'encoding': 'utf-8'}
                unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
                unpickler.persistent_load = persistent_load
                result = unpickler.load()
        else:
            deserialized_objects = {}
            f = open(fn_name, "rb")
            f_should_read_directly = _should_read_directly(f)
            MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
            PROTOCOL_VERSION = 1001
            pickle_load_args = {'encoding': 'utf-8'}
            magic_number = pickle.load(f, **pickle_load_args)
            if magic_number != MAGIC_NUMBER:
                raise RuntimeError("Invalid magic number; corrupt file?")
            protocol_version = pickle.load(f, **pickle_load_args)
            if PROTOCOL_VERSION != protocol_version:
                raise RuntimeError("Invalid protocol version.")
            _sys_info = pickle.load(f, **pickle_load_args)
            unpickler = DirectUnpicklerWrapper(f, **pickle_load_args)
            unpickler.persistent_load = persistent_load_direct
            result = unpickler.load()
            offset = f.tell() if f_should_read_directly else None
            deserialized_storage_keys = pickle.load(f, **pickle_load_args)
            f.read(8)
            for key in deserialized_storage_keys:
                assert key in deserialized_objects
                dtype = deserialized_objects[key].dtype
                size = deserialized_objects[key].size * get_dtype_size(dtype)
                byte_data = f.read(size)
                deserialized_objects[key][:] = np.frombuffer(byte_data, dtype).copy()
                f.read(8)
            if offset is not None:
                offset = f.tell()
            for key, params in result.items():
                requires_grad = params.requires_grad
                shape = params.size
                result[key] = jt.array(params.storage)
                if shape is not None and len(shape) > 0:
                    if len(params.stride) > 1:
                        eval_list = []
                        for idx in range(len(params.stride)):
                            eval_list.append(f"@e0({idx}) * i{idx}")
                        evals = "+".join(eval_list)
                        result[key] = result[key].reindex(params.size, [evals], extras=[jt.array(params.stride)])
                    else:
                        result[key] = result[key].reshape(shape)
                if requires_grad is not None:
                    result[key].requires_grad = requires_grad
    return result

if __name__ == "__main__":
    result = load_pytorch("van_base.pth")
    for key, val in result.items():
        print(key, val.shape)

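The reindex expression built in the legacy branch reconstructs a tensor from flat storage plus PyTorch strides: element (i0, i1, ...) is read from offset stride[0]*i0 + stride[1]*i1 + .... A small NumPy illustration of the same addressing, with values and strides invented for the example:

import numpy as np

storage = np.arange(6)           # flat buffer holding the tensor's elements
size, stride = (2, 3), (1, 2)    # a 2x3 tensor stored column-major style

out = np.empty(size, storage.dtype)
for i0 in range(size[0]):
    for i1 in range(size[1]):
        # same formula the reindex string "@e0(0)*i0 + @e0(1)*i1" encodes
        out[i0, i1] = storage[stride[0] * i0 + stride[1] * i1]
print(out)    # [[0 2 4], [1 3 5]]
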