mirror of https://github.com/Jittor/Jittor
load pth
This commit is contained in:
parent
8430c48ad5
commit
358e4b1b82
|
@ -465,7 +465,15 @@ def display_memory_info():
|
|||
core.display_memory_info(fileline)
|
||||
|
||||
def load(path):
|
||||
model_dict = safeunpickle(path)
|
||||
if path.endswith(".pth"):
|
||||
try:
|
||||
dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
except:
|
||||
raise RuntimeError("pytorch need to be installed when load pth format.")
|
||||
model_dict = torch.load(path, map_location=torch.device('cpu'))
|
||||
else:
|
||||
model_dict = safeunpickle(path)
|
||||
return model_dict
|
||||
|
||||
def _uniq(x):
|
||||
|
@ -678,15 +686,7 @@ class Module:
|
|||
safepickle(params_dict, path)
|
||||
|
||||
def load(self, path):
|
||||
if path.endswith(".pth"):
|
||||
try:
|
||||
dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
except:
|
||||
raise RuntimeError("pytorch need to be installed when load pth format.")
|
||||
self.load_parameters(torch.load(path, map_location=torch.device('cpu')))
|
||||
return
|
||||
self.load_parameters(safeunpickle(path))
|
||||
self.load_parameters(load(path))
|
||||
|
||||
def eval(self):
|
||||
def callback(parents, k, v, n):
|
||||
|
|
Loading…
Reference in New Issue