mirror of https://github.com/Jittor/Jittor
Update Load
This commit is contained in:
parent
4ae9171578
commit
f045e196c1
|
@ -428,12 +428,14 @@ def random(shape, dtype="float32", type="uniform"):
|
|||
jt.Var([[0.96788853 0.28334728 0.30482838]
|
||||
[0.46107793 0.62798643 0.03457401]], dtype=float32)
|
||||
'''
|
||||
# TODO: move those code to core
|
||||
if dtype in ["float16", "bfloat16"]:
|
||||
# TODO: make curand support fp16
|
||||
ret = ops.random(shape, "float32", type).cast(dtype)
|
||||
else:
|
||||
ret = ops.random(shape, dtype, type)
|
||||
|
||||
ret = ops.random(shape, "float32", type)
|
||||
## TODO: move those code to core
|
||||
#if dtype in ["float16", "bfloat16"]:
|
||||
# # TODO: make curand support fp16
|
||||
# ret = ops.random(shape, "float32", type).cast(dtype)
|
||||
#else:
|
||||
# ret = ops.random(shape, dtype, type)
|
||||
amp_reg = jt.flags.amp_reg
|
||||
if amp_reg:
|
||||
if amp_reg & 16:
|
||||
|
|
|
@ -347,9 +347,9 @@ def stack(x, dim=0):
|
|||
'''
|
||||
assert isinstance(x, Sequence)
|
||||
if len(x) < 2:
|
||||
return x[0].unsqueeze(dim)
|
||||
return jt.Var(x[0]).unsqueeze(dim)
|
||||
|
||||
res = [x_.unsqueeze(dim) for x_ in x]
|
||||
res = [jt.Var(x_).unsqueeze(dim) for x_ in x]
|
||||
return jt.concat(res, dim=dim)
|
||||
jt.Var.stack = stack
|
||||
|
||||
|
|
|
@ -142,6 +142,10 @@ struct NanoString {
|
|||
// @pyjt(is_int)
|
||||
inline bool is_int() const { return get(_int); }
|
||||
inline bool is_unsigned() const { return get(_unsigned); }
|
||||
// @pyjt(is_floating_point)
|
||||
inline bool is_floating_point() const { return get(_float); }
|
||||
// @pyjt(is_complex)
|
||||
inline bool is_complex() const { return false; }
|
||||
// @pyjt(is_float)
|
||||
inline bool is_float() const { return get(_float); }
|
||||
inline ns_t is_white() const { return get(_white_list); }
|
||||
|
|
|
@ -265,7 +265,13 @@ def load_pytorch(fn_name):
|
|||
else:
|
||||
raise RuntimeError(f"zipfile <{fn_name}> format error, data.pkl not found")
|
||||
|
||||
data_file = contents.read_var(prefix+"data.pkl").data.tobytes()
|
||||
data_file = contents.read_var(prefix+"data.pkl")
|
||||
#import pdb; pdb.set_trace();
|
||||
#print(data_file)
|
||||
if data_file.dtype == "uint8":
|
||||
data_file = data_file.numpy().tobytes()
|
||||
else:
|
||||
data_file = data_file.data.tobytes()
|
||||
data_file = io.BytesIO(data_file)
|
||||
pickle_load_args = {'encoding': 'utf-8'}
|
||||
unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
|
||||
|
|
Loading…
Reference in New Issue