This commit is contained in:
Dun Liang 2020-12-14 23:04:20 +08:00
parent 8430c48ad5
commit 358e4b1b82
1 changed files with 10 additions and 10 deletions

View File

@ -465,7 +465,15 @@ def display_memory_info():
core.display_memory_info(fileline)
def load(path):
model_dict = safeunpickle(path)
if path.endswith(".pth"):
try:
dirty_fix_pytorch_runtime_error()
import torch
except:
raise RuntimeError("pytorch need to be installed when load pth format.")
model_dict = torch.load(path, map_location=torch.device('cpu'))
else:
model_dict = safeunpickle(path)
return model_dict
def _uniq(x):
@ -678,15 +686,7 @@ class Module:
safepickle(params_dict, path)
def load(self, path):
if path.endswith(".pth"):
try:
dirty_fix_pytorch_runtime_error()
import torch
except:
raise RuntimeError("pytorch need to be installed when load pth format.")
self.load_parameters(torch.load(path, map_location=torch.device('cpu')))
return
self.load_parameters(safeunpickle(path))
self.load_parameters(load(path))
def eval(self):
def callback(parents, k, v, n):