Update Load

This commit is contained in:
zhangjiapeng 2024-04-01 18:00:34 +08:00
parent 4ae9171578
commit f045e196c1
4 changed files with 21 additions and 9 deletions

View File

@ -428,12 +428,14 @@ def random(shape, dtype="float32", type="uniform"):
jt.Var([[0.96788853 0.28334728 0.30482838]
[0.46107793 0.62798643 0.03457401]], dtype=float32)
'''
# TODO: move those code to core
if dtype in ["float16", "bfloat16"]:
# TODO: make curand support fp16
ret = ops.random(shape, "float32", type).cast(dtype)
else:
ret = ops.random(shape, dtype, type)
ret = ops.random(shape, "float32", type)
## TODO: move those code to core
#if dtype in ["float16", "bfloat16"]:
# # TODO: make curand support fp16
# ret = ops.random(shape, "float32", type).cast(dtype)
#else:
# ret = ops.random(shape, dtype, type)
amp_reg = jt.flags.amp_reg
if amp_reg:
if amp_reg & 16:

View File

@ -347,9 +347,9 @@ def stack(x, dim=0):
'''
assert isinstance(x, Sequence)
if len(x) < 2:
return x[0].unsqueeze(dim)
return jt.Var(x[0]).unsqueeze(dim)
res = [x_.unsqueeze(dim) for x_ in x]
res = [jt.Var(x_).unsqueeze(dim) for x_ in x]
return jt.concat(res, dim=dim)
jt.Var.stack = stack

View File

@ -142,6 +142,10 @@ struct NanoString {
// @pyjt(is_int)
inline bool is_int() const { return get(_int); }
inline bool is_unsigned() const { return get(_unsigned); }
// @pyjt(is_floating_point)
inline bool is_floating_point() const { return get(_float); }
// @pyjt(is_complex)
inline bool is_complex() const { return false; }
// @pyjt(is_float)
inline bool is_float() const { return get(_float); }
inline ns_t is_white() const { return get(_white_list); }

View File

@ -265,7 +265,13 @@ def load_pytorch(fn_name):
else:
raise RuntimeError(f"zipfile <{fn_name}> format error, data.pkl not found")
data_file = contents.read_var(prefix+"data.pkl").data.tobytes()
data_file = contents.read_var(prefix+"data.pkl")
#import pdb; pdb.set_trace();
#print(data_file)
if data_file.dtype == "uint8":
data_file = data_file.numpy().tobytes()
else:
data_file = data_file.data.tobytes()
data_file = io.BytesIO(data_file)
pickle_load_args = {'encoding': 'utf-8'}
unpickler = UnpicklerWrapper(data_file, **pickle_load_args)