mirror of https://github.com/Jittor/Jittor

fix master

This commit is contained in:
parent c78db2a794
commit 4017b161d2
@ -1,430 +0,0 @@
# import os
# os.environ["FIX_TORCH_ERROR"] = "0"

# import jittor as jt
# from jittor import *
# from typing import Tuple

# org_int = int = type(1)
# org_float = float = type(1.0)
# org_bool = bool = type(True)

# import jtorch.compiler

# import jtorch_core
# from jtorch_core import *

# device.__reduce__ = lambda self: (device, (self.type,))
# device.__module__ = "jtorch"
# jt.jittor_core.device = device

# def handle_dtype(args, kw, dtype):
#     def convert(x):
#         if isinstance(x, jt.Var):
#             return x.cast(dtype)
#         return x
#     if dtype is not None:
#         if args is not None:
#             if isinstance(args, (tuple, list)):
#                 args = [convert(a) for a in args]
#             else:
#                 args = convert(args)
#         if kw is not None:
#             kw = {k: convert(v) for k, v in kw.items()}
#     return args, kw

# def get_args_names(func):
#     import inspect
#     spec = inspect.getfullargspec(func)
#     return spec[0] + spec[4]

# def wrapper(func):
#     has_dtype = False
#     if hasattr(func, "__code__"):
#         has_dtype = "dtype" in get_args_names(func)
#     def inner(*args, **kw):
#         requires_grad = None
#         dtype = None
#         if "requires_grad" in kw:
#             requires_grad = kw["requires_grad"]
#             del kw["requires_grad"]
#         if not has_dtype and "dtype" in kw:
#             dtype = kw["dtype"]
#             del kw["dtype"]
#         if "device" in kw:
#             del kw["device"]
#         if 'pin_memory' in kw:
#             del kw['pin_memory']
#         args, kw = handle_dtype(args, kw, dtype)
#         ret = func(*args, **kw)
#         if isinstance(ret, jt.Var):
#             if requires_grad is not None:
#                 ret.requires_grad = requires_grad
#             if dtype is not None:
#                 ret.astype(dtype)
#         return ret
#     return inner


# import inspect
# _wrapper_keys = set(["shape", "start", "size"])
# _wrapper_keys.add("x")
# for k, v in list(globals().items()):
#     if callable(v) and not isinstance(v, type):
#         try:
#             spec = inspect.getfullargspec(v)
#             args_name = spec[0]
#             if len(args_name) and args_name[0] in _wrapper_keys:
#                 globals()[k] = wrapper(v)
#             elif spec.varargs in _wrapper_keys:
#                 globals()[k] = wrapper(v)
#         except:
#             pass

# def empty(*size, dtype=jt.float32, device=None, requires_grad=False):
#     if len(size) == 1 and not isinstance(size[0], org_int):
#         size = size[0]
#     return jt.empty(size, dtype)

# Tensor = Var

# Tensor.backward = lambda x: jtorch_core.backward(x)
# Tensor.grad = property(grad_get, grad_set, grad_del)
# Tensor.retains_grad = property(retain_grad_get, retain_grad_set)
# def retain_grad(x: Tensor, value: bool = True):
#     x.retains_grad = value
#     return value
# Tensor.retain_grad = retain_grad

# Tensor.dim = lambda self: self.ndim
# Tensor.ndimension = lambda self: self.ndim
# Tensor.nelement = lambda self: self.numel()
# Tensor.cuda = lambda self: self
# def device_get(x: Tensor):
#     return device("cpu") if not jt.has_cuda or not jt.flags.use_cuda else device("cuda")
# Tensor.device = property(device_get)

# def argmax(x: Var, dim=None, keepdim: bool = False):
#     return jt.argmax(x, dim, keepdim)[0]
# Tensor.argmax = argmax

# def tensor_type(x: Var, dtype=None, **kwargs):
#     if dtype:
#         return x.astype(dtype)
#     else:
#         return x.dtype
# Tensor.type = tensor_type

# def is_floating_point(x: Var):
#     return "float" in str(x.dtype)
# Tensor.is_floating_point = is_floating_point

# from . import autograd
# from .autograd import *

# def tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False):
#     if isinstance(data, list):
#         data_list = []
#         check = True
#         for p in data:
#             if isinstance(p, Tensor) and p.numel() == 1:
#                 data_list.append(p.item())
#             elif isinstance(p, (org_int, org_float)):
#                 data_list.append(p)
#             else:
#                 check = False
#                 break
#         if check:
#             data = data_list
#     return wrapper(array)(data, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory)

# # tensor = wrapper(array)
# from_numpy = wrapper(array)
# strided = None

# def mod_zero_grad(self):
#     for p in self.parameters():
#         p.grad = None
# Module.zero_grad = mod_zero_grad

# class ModuleMisc:
#     def parameters(self):
#         return iter(super().parameters())
#
#     def load_state_dict(self, state_dict, strict=False):
#         return super().load_state_dict(state_dict)
#
#     def to(self, device=None, dtype=None):
#         ''' do nothing but return itself '''
#         return self
#     def register_parameter(self, name, data):
#         self.name = data
#
#     def buffers(self):
#         for _, buf in self.named_buffers():
#             yield buf


# def make_module(cls):
#     class TMod(ModuleMisc, cls):
#         def __init__(self, *args, **kw):
#             dtype = None
#             if "dtype" in kw:
#                 dtype = kw["dtype"]
#                 del kw["dtype"]
#             self._dtype = dtype
#             with jt.flag_scope(th_mode=0):
#                 if "device" in kw:
#                     del kw["device"]
#                 super().__init__(*args, **kw)
#                 for k, v in self.__dict__.items():
#                     if not k.startswith("_") and isinstance(v, Var) \
#                             and v.requires_grad:
#                         v.retain_grad()
#                     if dtype is not None and isinstance(v, Var):
#                         v.assign(v.cast(dtype))
#         def __call__(self, *args, **kw):
#             args, kw = handle_dtype(args, kw, self._dtype)
#             # if forward is overridden by the user, call forward
#             if self.__class__.forward is not TMod.forward:
#                 return self.forward(*args, **kw)
#             return self.execute(*args, **kw)
#         def forward(self, *args, **kw):
#             args, kw = handle_dtype(args, kw, self._dtype)
#             return self.execute(*args, **kw)
#
#         @property
#         def training(self):
#             if not hasattr(self, "is_train"):
#                 self.is_train = True
#             return self.is_train
#         @training.setter
#         def training(self, value):
#             self.is_train = value
#
#     TMod.__name__ = cls.__name__
#     return TMod

# import jtorch.cuda
# import jtorch.nn
# from jtorch.nn import Module, Parameter
# import jtorch.optim

# from jtorch.utils.dtype import Dtype, get_string_dtype

# def frombuffer(buffer: bytearray,
#                *,
#                dtype: Dtype,
#                count: int = -1,
#                offset: int = 0,
#                requires_grad: bool = True) -> Tensor:
#     dtype = get_string_dtype(dtype)
#     tensor = jt.array(np.frombuffer(buffer, dtype, count=count, offset=offset))
#     if requires_grad and tensor.dtype.is_float():
#         tensor.requires_grad = True
#     return tensor

# def conflict_wrapper(origin_func, new_func):
#     def wrapper(*args, **kw):
#         if jt.flags.th_mode:
#             return new_func(*args, **kw)
#         else:
#             return origin_func(*args, **kw)
#     return wrapper

# def min(*args, **kw):
#     dim = None
#     if len(args) >= 2 and isinstance(args[1], org_int):
#         dim = args[1]
#     elif "dim" in kw and isinstance(kw["dim"], org_int):
#         dim = kw["dim"]
#     if dim is not None:
#         k, v = jt.argmin(*args, **kw)
#         return v, k
#     elif len(args) == 2 and isinstance(args[1], jt.Var):
#         return jt.minimum(args[0], args[1])
#     else:
#         return jt.min(*args, **kw)
# Tensor.min = conflict_wrapper(jt.min, min)

# def max(*args, **kw):
#     dim = None
#     if "dim" in kw:
#         x = kw["dim"]
#     if len(args) >= 2 and isinstance(args[1], org_int):
#         dim = args[1]
#     elif "dim" in kw and isinstance(kw["dim"], org_int):
#         dim = kw["dim"]
#     if dim is not None:
#         k, v = jt.argmax(*args, **kw)
#         return v, k
#     elif len(args) == 2 and isinstance(args[1], jt.Var):
#         return jt.maximum(args[0], args[1])
#     else:
#         return jt.max(*args, **kw)
# Tensor.max = conflict_wrapper(jt.max, max)

# def argsort(*args, **kw):
#     k, v = jt.argsort(*args, **kw)
#     return k
# Tensor.argsort = conflict_wrapper(jt.argsort, argsort)

# LongTensor = jt.int64
# FloatTensor = jt.float
# HalfTensor = jt.float16
# BoolTensor = jt.bool
# IntTensor = jt.int32

# class JDType:
#     def __init__(self, func, str):
#         self.func = func
#         self.str = str
#         self.__name__ = str.split(".")[-1]
#     def __call__(self, *args, **kw):
#         return self.func(*args, **kw)
#     def __str__(self):
#         return self.str
#     def is_floating_point(self):
#         return "float" in str(self.str)

# int8 = JDType(jt.int8, "torch.int8")
# int16 = JDType(jt.int16, "torch.int16")
# int = int32 = JDType(jt.int32, "torch.int32")
# long = int64 = JDType(jt.int64, "torch.int64")

# half = float16 = JDType(jt.float16, "torch.float16")
# float = float32 = JDType(jt.float32, "torch.float32")
# double = float64 = JDType(jt.float64, "torch.float64")
# bfloat16 = "bfloat16" # TODO
# complex64 = "complex64" # TODO
# complex128 = "complex128" # TODO
# def get_JDtype(dtype):
#     if dtype == 'float32' or dtype == jt.float32:
#         return float32
#     elif dtype == 'float64' or dtype == jt.float64:
#         return float64
#     elif dtype == 'float16' or dtype == jt.float16:
#         return float16
#     elif dtype == 'int32' or dtype == jt.int32:
#         return int32
#     elif dtype == 'int64' or dtype == jt.int64:
#         return int64
#     elif dtype == 'int16' or dtype == jt.int16:
#         return int16
#     elif dtype == 'int8' or dtype == jt.int8:
#         return int8
#     else:
#         raise Exception("dtype {} not supported".format(dtype))

# def load(path, **kwargs):
#     def _to_jittor(data):
#         if isinstance(data, dict):
#             return {k: _to_jittor(d) for k, d in data.items()}
#         if isinstance(data, list):
#             return [_to_jittor(d) for d in data]
#         if isinstance(data, np.ndarray):
#             return jt.array(data)
#         return data
#     data = jt.load(path)
#     return _to_jittor(data)

# def is_tensor(x):
#     return isinstance(x, Tensor)

# manual_seed = jt.set_global_seed
# jt.flags.amp_level = 3
# Size = jt.NanoVector

# class Generator:
#     def __init__(self, *args, **kw) -> None:
#         self.seed = None
#     def manual_seed(self, seed):
#         self.seed = seed


# from . import fx


# _default_type = "float32"

# def get_default_dtype():
#     return _default_type
# def set_default_dtype(dtype):
#     global _default_type
#     _default_type = dtype

# dtype = JDType

# def div(x, y, rounding_mode="floor"):
#     assert rounding_mode == "floor"
#     z = (x / y)
#     if rounding_mode == "floor":
#         z = z.floor()
#     if x.dtype == "int32" and (isinstance(y, org_int) or y.dtype == "int32"):
#         z = z.int32()
#     return z


# def randn(*args, **kw):
#     wrap_randn = wrapper(jt.randn)
#     generator = kw.get('generator', None)
#     kw.pop('generator', None)
#     if 'layout' in kw:
#         del kw['layout']
#     if generator is not None and generator.seed is not None:
#         jt.set_global_seed(generator.seed)
#     return wrap_randn(*args, **kw)

# def rand(*args, **kw):
#     print("rand")  # leftover debug print
#     wrap_rand = wrapper(jt.rand)
#     generator = kw.get('generator', None)
#     kw.pop('generator', None)
#     if 'layout' in kw:
#         del kw['layout']
#     if generator is not None and generator.seed is not None:
#         jt.set_global_seed(generator.seed)
#     return wrap_rand(*args, **kw)


# def set_default_tensor_type(t: type or str):
#     if isinstance(t, str):
#         info = t.split(".")
#         if len(info) == 3 and info[1] == 'cuda':
#             jt.flags.use_cuda = 1
#     # TODO: type


# def clamp(x, min=None, max=None):
#     return jt.clamp(x, min, max)


# def to(x, *args, **kw):
#     device = None
#     if len(args) == 1:
#         device = args[0]
#         if isinstance(device, jt.NanoString) or callable(device):
#             return jt.to(x, *args, **kw)
#         if 'cpu' in str(device):
#             args = []
#     device = kw.get("device", None)
#     if 'cpu' in str(device):
#         kw.pop('device', None)
#         print("to cpu")  # leftover debug print
#     # print(kw)
#     return jt.to(x, *args, **kw)
# Tensor.to = conflict_wrapper(jt.to, to)

# mm = wrapper(jt.matmul)

# def _data_get(x):
#     return x

# def _data_set(x, value):
#     x.assign(value)

# Tensor.data = property(_data_get, _data_set)
# Tensor.layout = None
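The commented-out `wrapper` above is the heart of the deleted jtorch shim: it strips the torch-only keyword arguments (`requires_grad`, `device`, `pin_memory`, and `dtype` when the wrapped function does not accept it) before delegating to the underlying Jittor function, then re-applies `requires_grad` and `dtype` to the returned Var. A minimal standalone sketch of the same pattern, assuming only `jittor` is installed; the name `torch_compat` is illustrative and not part of the deleted file:

import jittor as jt

def torch_compat(func):
    # Illustrative re-implementation of the deleted `wrapper`:
    # strip kwargs that Jittor functions do not accept, then apply
    # requires_grad/dtype to the returned Var.
    def inner(*args, **kw):
        requires_grad = kw.pop("requires_grad", None)
        dtype = kw.pop("dtype", None)
        kw.pop("device", None)       # Jittor places data globally, not per-tensor
        kw.pop("pin_memory", None)   # meaningless under Jittor's memory model
        ret = func(*args, **kw)
        if isinstance(ret, jt.Var):
            if dtype is not None:
                ret = ret.cast(dtype)
            if requires_grad is not None:
                ret.requires_grad = requires_grad
        return ret
    return inner

ones = torch_compat(jt.ones)
x = ones((2, 3), requires_grad=True, device="cuda")  # device is silently dropped
print(x.requires_grad)  # True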
@ -1,134 +0,0 @@
import jittor as jt
from jittor import Var
from collections.abc import Sequence, Mapping

Variable = Var

class FunctionContext:
    def save_for_backward(self, *args):
        self.saved_tensors = args

class Function:
    ''' Function module for customized backward operations.

    Example 1 (a Function can take multiple inputs, return multiple outputs,
    and store values for the backward computation)::

        import jtorch
        from jtorch import Function

        class MyFunc(Function):
            @staticmethod
            def forward(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y

            @staticmethod
            def backward(self, grad0, grad1):
                return grad0 * self.y, grad1 * self.x

        a = jtorch.array(3.0)
        a.requires_grad = True
        b = jtorch.array(4.0)
        b.requires_grad = True
        func = MyFunc.apply
        c, d = func(a, b)
        (c + d*3).backward()
        assert a.grad.data == 4
        assert b.grad.data == 9

    Example 2 (a Function can return None for "no gradient", and an incoming
    gradient can also be None)::

        import jtorch
        from jtorch import Function

        class MyFunc(Function):
            @staticmethod
            def forward(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y

            @staticmethod
            def backward(self, grad0, grad1):
                assert grad1 is None
                return grad0 * self.y, None

        a = jt.array(3.0)
        a.requires_grad = True
        b = jt.array(4.0)
        b.requires_grad = True
        func = MyFunc.apply
        c, d = func(a, b)
        d.stop_grad()
        da, db = jt.grad(c + d*3, [a, b])
        assert da.data == 4
        assert db.data == 0

    '''
    def __call__(self, *args):
        backup = args
        args = list(args)
        taped_inputs = []
        taped_outputs = []
        input_mask = [-1] * len(args)
        for i, v in enumerate(args):
            if isinstance(v, Var):
                if v.is_stop_grad():
                    # -2 in input_mask means this input is stop_grad
                    input_mask[i] = -2
                    continue
                v = v.tape()
                input_mask[i] = len(taped_inputs)
                args[i] = v
                taped_inputs.append(v)
        ctx = FunctionContext()
        ori_res = self.forward(ctx, *args)
        # ori_res = self.execute(*args)
        if not isinstance(ori_res, Sequence):
            res = [ori_res]
        else:
            res = list(ori_res)
        output_mask = [-1] * len(res)
        for i, v in enumerate(res):
            if isinstance(v, Var):
                v = v.tape()
                output_mask[i] = len(taped_outputs)
                res[i] = v
                taped_outputs.append(v)
        ctx.input_mask = input_mask
        ctx.output_mask = output_mask
        # tape outputs and inputs together so that
        # backward treats them as one operator
        jt.tape_together(taped_inputs, taped_outputs,
                         lambda *args: self._grad(ctx, self, *args))
        if isinstance(ori_res, Sequence):
            return res
        else:
            return res[0]

    @staticmethod
    def _grad(ctx, func, *args):
        new_args = ((args[i] if i >= 0 else None) for i in ctx.output_mask)
        ret = func.backward(ctx, *new_args)
        if not isinstance(ret, Sequence):
            ret = (ret,)
        new_ret = []
        for i, r in enumerate(ret):
            j = ctx.input_mask[i]
            if j < 0:
                # -2 in input_mask means this input is stop_grad
                assert r is None or j == -2, f"{type(func)}'s {i}-th returned grad should be None, "\
                    "because the input value is not a jittor variable."
            else:
                new_ret.append(r)
        return new_ret

    def dfs(self, parents, k, callback, callback_leave=None):
        pass

    @classmethod
    def apply(cls, *args, **kw):
        func = cls()
        return func(*args, **kw)
@ -1,39 +0,0 @@
import jittor as jt
import jittor_utils
import glob
import os
from jittor import pyjt_compiler
import sys
from jittor_utils import lock


jtorch_path = os.path.dirname(__file__)
cache_path = os.path.join(jt.compiler.cache_path, "jtorch")
# os.makedirs(cache_path, exist_ok=True)
os.makedirs(os.path.join(cache_path, "gen"), exist_ok=True)

with lock.lock_scope():
    pyjt_gen_src = pyjt_compiler.compile(cache_path, jtorch_path)

ext_args = 'c[cu]' if jt.has_cuda else 'cc'
files = glob.glob(jtorch_path + "/src/**/*." + ext_args, recursive=True)
files += pyjt_gen_src
cc_flags = " -I\"" + os.path.join(jtorch_path, "src") + "\" "
if os.environ.get("use_data_o", "1") == "1":
    files += glob.glob(jtorch_path + "/src/**/*.o", recursive=True)
files = [f for f in files if "__data__" not in f]


with lock.lock_scope():
    jt.compiler.compile(
        jt.compiler.cc_path,
        jt.compiler.cc_flags + jt.compiler.opt_flags + cc_flags,
        files,
        "jtorch_core" + jt.compiler.extension_suffix,
        obj_dirname="jtorch_objs")


with jittor_utils.import_scope(jt.compiler.import_flags):
    import jtorch_core as core

jt.flags.th_mode = 1
@ -1,64 +0,0 @@
import jittor as jt
import jtorch

def is_available():
    return jt.has_cuda

def device_count():
    return int(jt.has_cuda)

def set_device(device=None):
    pass

def get_rng_state(device=None):
    pass

def current_device():
    return jtorch.device("cuda")

def mem_get_info(i):
    return ("75GB",)


class Generator:
    def __init__(self):
        pass

    def set_state(self, state):
        self.state = state

default_generators = [Generator()]
_lazy_call = lambda func: func()
device = None

LongTensor = jt.int64
FloatTensor = jt.float
HalfTensor = jt.float16
BoolTensor = jt.bool

manual_seed = jt.set_global_seed
manual_seed_all = jt.set_global_seed

def synchronize():
    jt.sync_all(True)

class Event:
    pass

class Stream:
    pass

from typing import Any

from .gradscaler import GradScaler

class autocast:
    def __init__(self, **kwargs):
        pass

    def __enter__(self):
        pass

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
        pass
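With these stubs, torch-style device bookkeeping becomes a thin veneer over Jittor's single global `use_cuda` flag: there is at most one device, and seeding and synchronization map directly onto Jittor globals. A small sketch of what calling code sees, assuming the module is importable as `jtorch.cuda` per the layout above:

import jittor as jt
import jtorch.cuda as cuda   # import path assumed from the deleted module layout

if cuda.is_available():      # just jt.has_cuda underneath
    jt.flags.use_cuda = 1    # Jittor's single global device switch
print(cuda.device_count())   # 0 or 1: the stub reports at most one device
cuda.manual_seed(42)         # alias for jt.set_global_seed
cuda.synchronize()           # jt.sync_all(True): wait for queued ops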
@ -1,53 +0,0 @@
import datetime
from enum import Enum
import jittor as jt


class DistributedDataParallel:
    def __new__(cls, model):
        return model

def is_initialized():
    return True

def get_rank(group=None):
    return 0

def get_world_size(group=None):
    return 1

def get_backend(group=None):
    return "nccl"

def new_group(ranks=None, timeout=datetime.timedelta(seconds=1800), backend=None, pg_options=None):
    return 1

def barrier():
    pass

def is_available():
    return True

def is_built():
    return True

class ReduceOp:
    SUM = 0

class GroupMember:
    WORLD = 0

class ProcessGroup:
    pass

class Join:
    pass

dist_backend = Enum("dist_backend", ("GLOO", "MPI", "NCCL"))
_backend = dist_backend.NCCL

def is_mpi_available():
    return jt.in_mpi

# NOTE: this function shadows the class of the same name defined above;
# both variants simply hand the model back unchanged.
def DistributedDataParallel(model, *args, **kw):
    return model
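Together these stubs let torch.distributed-style training scripts run single-process on Jittor without modification: rank is always 0, world size is 1, and `DistributedDataParallel` returns the model untouched. A sketch of the calling pattern they support; the import path and the `build_model`/`save_checkpoint` helpers are hypothetical:

import jtorch.distributed as dist   # hypothetical import path for the stubs above

model = build_model()                         # hypothetical model constructor
model = dist.DistributedDataParallel(model)   # stub: hands back `model` unchanged
if dist.get_rank() == 0:                      # always true here: rank is fixed at 0
    save_checkpoint(model)                    # hypothetical helper
dist.barrier()                                # no-op in the single-process stub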
@ -1,15 +0,0 @@
import jittor as jt

class RelaxedBernoulli:
    def __init__(self, temperature, probs=None, logits=None):
        self.temperature = temperature
        self.probs = probs
        self.logits = logits

    def rsample(self):
        noise = jt.rand_like(self.logits)
        eps = 1e-20
        noise = jt.clamp(noise, eps, 1.0 - eps)
        logit_noise = jt.log(noise) - jt.log(1 - noise)
        sample = (self.logits + logit_noise) / self.temperature
        return jt.sigmoid(sample)
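`rsample` implements the standard reparameterized (Gumbel-sigmoid, a.k.a. "concrete") relaxation of a Bernoulli draw: logistic noise log(u) - log(1 - u) is added to the logits and the sum is squashed through a temperature-scaled sigmoid, so samples stay differentiable with respect to `logits`. A short usage sketch, assuming the class above is in scope:

import jittor as jt

logits = jt.array([0.5, -1.0, 2.0])
logits.requires_grad = True

rb = RelaxedBernoulli(temperature=0.5, logits=logits)
sample = rb.rsample()            # values in (0, 1), one per logit
loss = sample.sum()
grads = jt.grad(loss, [logits])  # gradient flows through the relaxed sample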
@ -1,5 +0,0 @@
# TODO: Implement FFT and IFFT
fftn = None
fftshift = None
ifftn = None
ifftshift = None
@ -1,2 +0,0 @@
class Proxy:
    pass
@ -1,519 +0,0 @@
from collections import defaultdict, abc
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, cast
import inspect
import warnings

import jittor as jt
# import torch

def _refresh_per_optimizer_state():
    return {}


class GradScaler:
    _scale: Optional[jt.Var]
    _growth_tracker: Optional[jt.Var]
    _per_optimizer_states: Dict[int, Dict[str, Any]]
    """
    An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
    conveniently.

    * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
    * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
    * ``scaler.update()`` updates ``scaler``'s scale factor.

    Example::

        # Creates a GradScaler once at the beginning of training.
        scaler = GradScaler()

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()
                output = model(input)
                loss = loss_fn(output, target)

                # Scales loss. Calls backward() on scaled loss to create scaled gradients.
                scaler.scale(loss).backward()

                # scaler.step() first unscales gradients of the optimizer's params.
                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(optimizer)

                # Updates the scale for next iteration.
                scaler.update()

    See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
    (along with autocasting) in more complex cases like gradient clipping, gradient accumulation,
    gradient penalty, and multiple losses/optimizers.

    ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
    a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
    the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
    without incurring inf or NaN gradient values.
    ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
    ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).

    * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
      themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.

    * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
      If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
      ``growth_factor``.

    The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
    value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
    iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).

    Args:
        init_scale (float, optional, default=2.**16): Initial scale factor.
        growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
            :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
        backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
            :meth:`update` if inf/NaN gradients occur in an iteration.
        growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
            that must occur for the scale to be multiplied by ``growth_factor``.
        enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
            invokes the underlying ``optimizer.step()``, and other methods become no-ops.
            Default: ``True``
    """
    def __init__(self,
                 init_scale=2.**16,
                 growth_factor=2.0,
                 backoff_factor=0.5,
                 growth_interval=2000,
                 enabled=True):
        self._enabled = enabled

        if self._enabled:
            assert growth_factor > 1.0, "The growth factor must be > 1.0."
            assert backoff_factor < 1.0, "The backoff factor must be < 1.0."

            self._init_scale = init_scale
            # self._scale will be lazily initialized during the first call to scale()
            self._scale = None
            self._growth_factor = growth_factor
            self._backoff_factor = backoff_factor
            self._growth_interval = growth_interval
            self._init_growth_tracker = 0
            # self._growth_tracker will be lazily initialized during the first call to scale()
            self._growth_tracker = None
            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    def _check_scale_growth_tracker(self, funcname) -> Tuple[jt.Var, jt.Var]:
        fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
        assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
        assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
        return (self._scale, self._growth_tracker)

    def _lazy_init_scale_growth_tracker(self):
        assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
        self._scale = self._init_scale
        self._growth_tracker = self._init_growth_tracker

    def scale(self, outputs):
        """
        Multiplies ('scales') a tensor or list of tensors by the scale factor.

        Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
        unmodified.

        Args:
            outputs (Tensor or iterable of Tensors): Outputs to scale.
        """
        if not self._enabled:
            return outputs

        # Short-circuit for the common case.
        if isinstance(outputs, jt.Var):
            assert jt.flags.use_cuda == 1
            if self._scale is None:
                self._lazy_init_scale_growth_tracker()
            assert self._scale is not None
            return outputs * self._scale

        def apply_scale(val):
            if isinstance(val, jt.Var):
                assert jt.flags.use_cuda == 1
                if self._scale is None:
                    self._lazy_init_scale_growth_tracker()
                assert self._scale is not None
                return val * self._scale
            elif isinstance(val, abc.Iterable):
                iterable = map(apply_scale, val)
                if isinstance(val, (list, tuple)):
                    return type(val)(iterable)
                else:
                    return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)

    def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
        with jt.no_grad():
            optimizer.pre_step()
            for group in optimizer.param_groups:
                for to_unscale in group["grads"]:
                    if to_unscale is None or isinstance(to_unscale, (int, float)):
                        continue
                    if (not allow_fp16) and str(to_unscale.dtype) == "float16":
                        raise ValueError("Attempting to unscale FP16 gradients.")

                    if not (to_unscale.isinf().any()):
                        if inv_scale != 1.0:
                            to_unscale.update(to_unscale * inv_scale)
                    else:
                        found_inf = 1.0

        return found_inf

    def unscale_(self, optimizer):
        """
        Divides ("unscales") the optimizer's gradient tensors by the scale factor.

        :meth:`unscale_` is optional, serving cases where you need to
        :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
        between the backward pass(es) and :meth:`step`.
        If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.

        Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::

            ...
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()

        Args:
            optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.

        .. note::
            :meth:`unscale_` does not incur a CPU-GPU sync.

        .. warning::
            :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
            and only after all gradients for that optimizer's assigned parameters have been accumulated.
            Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.

        .. warning::
            :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
        """
        if not self._enabled:
            return

        self._check_scale_growth_tracker("unscale_")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if hasattr(optimizer, "get_find_inf"):
            return
        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
        assert self._scale is not None
        inv_scale = 1.0 / self._scale
        found_inf = 0.0
        optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)

    def step(self, optimizer, *args, **kwargs):
        """
        :meth:`step` carries out the following two operations:

        1.  Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
            earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
        2.  If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
            gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.

        ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.

        Returns the return value of ``optimizer.step(*args, **kwargs)``.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
            args: Any arguments.
            kwargs: Any keyword arguments.

        .. warning::
            Closure use is not currently supported.
        """
        if (not self._enabled):
            return optimizer.step(*args, **kwargs)

        if "closure" in kwargs:
            raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")

        self._check_scale_growth_tracker("step")

        optimizer_state = self._per_optimizer_states[id(optimizer)]
        retval = None

        if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling):
            # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
            # The contract with custom optimizers is that their step() should accept an additional,
            # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
            # it can query its own state, invoke unscale_ on itself, etc.
            # The contract above is being deprecated to avoid introducing a `grad_scaler: GradScaler` argument
            # to `Optimizer.step`. The new behavior is going to add two Tensor attributes, `grad_scale`
            # and `found_inf`, to the passed optimizer so that the optimizer can utilize those
            # to skip the parameter updates or unscale gradients before updating parameters in
            # the fused kernel, e.g. `FusedAdamMathFunctor`.
            # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`,
            # while the method is expected to be called on the user's side, i.e. from their optimizers.
            # NOTE: `OptState` and the "stage" key are defined only in the newer GradScaler variant below;
            # this older file would raise a NameError if this branch were taken.
            kwargs_ = kwargs
            has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters
            if has_grad_scaler_kwarg:
                warnings.warn(
                    "GradScaler is going to stop passing itself as a keyword argument to the passed "
                    "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and "
                    "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.",
                    FutureWarning)
                kwargs_.update({"grad_scaler": self})
            else:
                if optimizer_state["stage"] is OptState.READY:
                    self._check_inf_per_device(optimizer)
                scaler = self._get_scale_async()
                found_inf = cast(
                    jt.Var,
                    sum([
                        t for t in optimizer_state["found_inf_per_device"].values()
                    ])
                )
                optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler
                optimizer.found_inf = found_inf
            retval = optimizer.step(*args, **kwargs_)
            optimizer_state["stage"] = OptState.STEPPED
            if not has_grad_scaler_kwarg:
                del optimizer.grad_scale
                del optimizer.found_inf
            return retval

        if hasattr(optimizer, "get_find_inf"):
            optimizer.set_grad_scale(self._scale)
            optimizer.step()
            optimizer_state["found_inf_per_device"] = optimizer.get_find_inf()
            return

        retval = None
        if not optimizer_state["found_inf_per_device"]:
            retval = optimizer.step(*args, **kwargs)
        else:
            optimizer.post_step()

        return retval

    def update(self, new_scale=None):
        """
        Updates the scale factor.

        If any optimizer steps were skipped, the scale is multiplied by ``backoff_factor``
        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
        the scale is multiplied by ``growth_factor`` to increase it.

        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
        used directly; it's used to fill GradScaler's internal scale tensor. So if
        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
        affect the scale GradScaler uses internally.)

        Args:
            new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.

        .. warning::
            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
            been invoked for all optimizers used this iteration.
        """
        if not self._enabled:
            return

        _scale, _growth_tracker = self._check_scale_growth_tracker("update")

        if new_scale is not None:
            # Accept a new user-defined scale.
            if isinstance(new_scale, float):
                self._scale.fill_(new_scale)  # type: ignore[union-attr]
            else:
                reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
                assert isinstance(new_scale, torch.cuda.FloatTensor), reason  # type: ignore[attr-defined]
                assert new_scale.numel() == 1, reason
                assert new_scale.requires_grad is False, reason
                self._scale.copy_(new_scale)  # type: ignore[union-attr]
        else:
            # Consume shared inf/nan data collected from optimizers to update the scale.
            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
            found_infs = [state["found_inf_per_device"]
                          for state in self._per_optimizer_states.values()]

            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]
            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf_combined += found_infs[i]

            current_scale = _scale
            if found_inf_combined:
                current_scale *= self._backoff_factor
                _growth_tracker = 0
            else:
                successful = _growth_tracker + 1
                if successful == self._growth_interval:
                    new_scale = current_scale * self._growth_factor
                    if new_scale < 1e9:
                        current_scale = new_scale
                    _growth_tracker = 0
                else:
                    _growth_tracker = successful

            self._scale, self._growth_tracker = current_scale, _growth_tracker

        # To prepare for next iteration, clear the data collected from optimizers this iteration.
        self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    def _get_scale_async(self):
        return self._scale

    def get_scale(self):
        """
        Returns a Python float containing the current scale, or 1.0 if scaling is disabled.

        .. warning::
            :meth:`get_scale` incurs a CPU-GPU sync.
        """
        if self._enabled:
            return self._init_scale if self._scale is None else self._get_scale_async()
        else:
            return 1.0

    def get_growth_factor(self):
        r"""
        Returns a Python float containing the scale growth factor.
        """
        return self._growth_factor

    def set_growth_factor(self, new_factor):
        r"""
        Args:
            new_factor (float): Value to use as the new scale growth factor.
        """
        self._growth_factor = new_factor

    def get_backoff_factor(self):
        r"""
        Returns a Python float containing the scale backoff factor.
        """
        return self._backoff_factor

    def set_backoff_factor(self, new_factor):
        r"""
        Args:
            new_factor (float): Value to use as the new scale backoff factor.
        """
        self._backoff_factor = new_factor

    def get_growth_interval(self):
        r"""
        Returns a Python int containing the growth interval.
        """
        return self._growth_interval

    def set_growth_interval(self, new_interval):
        r"""
        Args:
            new_interval (int): Value to use as the new growth interval.
        """
        self._growth_interval = new_interval

    def _get_growth_tracker(self):
        if self._enabled:
            return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item()
        else:
            return 0

    def is_enabled(self):
        r"""
        Returns a bool indicating whether this instance is enabled.
        """
        return self._enabled

    def state_dict(self):
        r"""
        Returns the state of the scaler as a :class:`dict`. It contains five entries:

        * ``"scale"`` - a Python float containing the current scale
        * ``"growth_factor"`` - a Python float containing the current growth factor
        * ``"backoff_factor"`` - a Python float containing the current backoff factor
        * ``"growth_interval"`` - a Python int containing the current growth interval
        * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.

        If this instance is not enabled, returns an empty dict.

        .. note::
           If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
           should be called after :meth:`update`.
        """
        return {"scale": self.get_scale(),
                "growth_factor": self._growth_factor,
                "backoff_factor": self._backoff_factor,
                "growth_interval": self._growth_interval,
                "_growth_tracker": self._get_growth_tracker()} if self._enabled else {}

    def load_state_dict(self, state_dict):
        r"""
        Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.

        Args:
           state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
        """
        if not self._enabled:
            return

        if len(state_dict) == 0:
            raise RuntimeError("The source state dict is empty, possibly because it was saved "
                               "from a disabled instance of GradScaler.")

        self._init_scale = state_dict["scale"]
        if self._scale is not None:
            self._scale.fill_(state_dict["scale"])
        self._growth_factor = state_dict["growth_factor"]
        self._backoff_factor = state_dict["backoff_factor"]
        self._growth_interval = state_dict["growth_interval"]
        self._init_growth_tracker = state_dict["_growth_tracker"]
        if self._growth_tracker is not None:
            self._growth_tracker.fill_(state_dict["_growth_tracker"])

    def __getstate__(self):
        state = self.__dict__.copy()
        if self._enabled:
            assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\
                "of an iteration, or at the end after scaler.update()."
            # Pickling _scale and _growth_tracker Tensors directly triggers
            # "warnings.warn("pickle support for Storage will be removed in 1.5..."
            # so instead, we set the unpickled instance up to reinitialize them lazily.
            state['_init_scale'] = self.get_scale()
            state['_init_growth_tracker'] = self._get_growth_tracker()
            state['_scale'] = None
            state['_growth_tracker'] = None
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)

    def _check_inf_per_device(self, optimizer):
        _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")

        dummy_inv_scale = 1.0
        found_inf = 0.0

        self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
            self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)

        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

    def _found_inf_per_device(self, optimizer):
        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
@ -1,556 +0,0 @@
from collections import defaultdict, abc
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, cast
import inspect
import warnings

import jittor as jt
# import torch


__all__ = ["OptState", "GradScaler"]


# Defines default_factory for GradScaler's _per_optimizer_states defaultdict,
# as well as associated "enum" values. Prefers defining these at top level because
# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory.
# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler
#   causes a circular reference, which we'd rather avoid.
class OptState(Enum):
    READY = 0
    UNSCALED = 1
    STEPPED = 2


def _refresh_per_optimizer_state():
    return {"stage": OptState.READY, "found_inf_per_device": {}}


class GradScaler:
    _scale: Optional[jt.Var]
    _growth_tracker: Optional[jt.Var]
    _per_optimizer_states: Dict[int, Dict[str, Any]]
    """
    An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
    conveniently.

    * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
    * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
    * ``scaler.update()`` updates ``scaler``'s scale factor.

    Example::

        # Creates a GradScaler once at the beginning of training.
        scaler = GradScaler()

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()
                output = model(input)
                loss = loss_fn(output, target)

                # Scales loss. Calls backward() on scaled loss to create scaled gradients.
                scaler.scale(loss).backward()

                # scaler.step() first unscales gradients of the optimizer's params.
                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(optimizer)

                # Updates the scale for next iteration.
                scaler.update()

    See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
    (along with autocasting) in more complex cases like gradient clipping, gradient accumulation,
    gradient penalty, and multiple losses/optimizers.

    ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
    a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
    the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
    without incurring inf or NaN gradient values.
    ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
    ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).

    * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
      themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.

    * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
      If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
      ``growth_factor``.

    The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
    value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
    iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).

    Args:
        init_scale (float, optional, default=2.**16): Initial scale factor.
        growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
            :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
        backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
            :meth:`update` if inf/NaN gradients occur in an iteration.
        growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
            that must occur for the scale to be multiplied by ``growth_factor``.
        enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
            invokes the underlying ``optimizer.step()``, and other methods become no-ops.
            Default: ``True``
    """
    def __init__(self,
                 init_scale=2.**16,
                 growth_factor=2.0,
                 backoff_factor=0.5,
                 growth_interval=2000,
                 enabled=True):
        self._enabled = enabled

        if self._enabled:
            assert growth_factor > 1.0, "The growth factor must be > 1.0."
            assert backoff_factor < 1.0, "The backoff factor must be < 1.0."

            self._init_scale = init_scale
            # self._scale will be lazily initialized during the first call to scale()
            self._scale = None
            self._growth_factor = growth_factor
            self._backoff_factor = backoff_factor
            self._growth_interval = growth_interval
            self._init_growth_tracker = 0
            # self._growth_tracker will be lazily initialized during the first call to scale()
            self._growth_tracker = None
            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    def _check_scale_growth_tracker(self, funcname) -> Tuple[jt.Var, jt.Var]:
        fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
        assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
        assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
        return (self._scale, self._growth_tracker)

    def _lazy_init_scale_growth_tracker(self):
        assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
        self._scale = self._init_scale
        self._growth_tracker = self._init_growth_tracker

    def scale(self, outputs):
        """
        Multiplies ('scales') a tensor or list of tensors by the scale factor.

        Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
        unmodified.

        Args:
            outputs (Tensor or iterable of Tensors): Outputs to scale.
        """
        print("scale")  # leftover debug print
        if not self._enabled:
            return outputs

        # Short-circuit for the common case.
        if isinstance(outputs, jt.Var):
            assert jt.flags.use_cuda == 1
            if self._scale is None:
                self._lazy_init_scale_growth_tracker()
            assert self._scale is not None
            return outputs * self._scale

        def apply_scale(val):
            if isinstance(val, jt.Var):
                assert jt.flags.use_cuda == 1
                if self._scale is None:
                    self._lazy_init_scale_growth_tracker()
                assert self._scale is not None
                return val * self._scale
            elif isinstance(val, abc.Iterable):
                iterable = map(apply_scale, val)
                if isinstance(val, (list, tuple)):
                    return type(val)(iterable)
                else:
                    return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)

    def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):

        # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
        # There could be hundreds of grads, so we'd like to iterate through them just once.
        # However, we don't know their devices or dtypes in advance.

        # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
        # Google says mypy struggles with defaultdicts type annotations.
        with jt.no_grad():
            optimizer.pre_step()
            for group in optimizer.param_groups:
                for to_unscale in group["grads"]:
                    if to_unscale is None or isinstance(to_unscale, (int, float)):
                        continue
                    if (not allow_fp16) and str(to_unscale.dtype) == "float16":
                        raise ValueError("Attempting to unscale FP16 gradients.")

                    if not (to_unscale.isinf().any()):
                        if inv_scale != 1.0:
                            to_unscale.update(to_unscale * inv_scale)
                    else:
                        found_inf = 1.0

        return found_inf

    def unscale_(self, optimizer):
        """
        Divides ("unscales") the optimizer's gradient tensors by the scale factor.

        :meth:`unscale_` is optional, serving cases where you need to
        :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
        between the backward pass(es) and :meth:`step`.
        If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.

        Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::

            ...
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()

        Args:
            optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.

        .. note::
            :meth:`unscale_` does not incur a CPU-GPU sync.

        .. warning::
            :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
            and only after all gradients for that optimizer's assigned parameters have been accumulated.
            Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.

        .. warning::
            :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
        """
        if not self._enabled:
            return

        self._check_scale_growth_tracker("unscale_")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.UNSCALED:
            raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")
        elif optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError("unscale_() is being called after step().")

        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
        assert self._scale is not None
        inv_scale = 1.0 / self._scale
        found_inf = 0.0
        optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
        optimizer_state["stage"] = OptState.UNSCALED

    def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
        retval = None
        if not optimizer_state["found_inf_per_device"]:
            retval = optimizer.step(*args, **kwargs)
        else:
            optimizer.post_step()

        return retval

    def step(self, optimizer, *args, **kwargs):
        """
        :meth:`step` carries out the following two operations:

        1.  Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
            earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
        2.  If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
            gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.

        ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.

        Returns the return value of ``optimizer.step(*args, **kwargs)``.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
            args: Any arguments.
            kwargs: Any keyword arguments.

        .. warning::
            Closure use is not currently supported.
        """
        if (not self._enabled):
            return optimizer.step(*args, **kwargs)

        if "closure" in kwargs:
            raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")

        self._check_scale_growth_tracker("step")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError("step() has already been called since the last update().")

        retval = None

        if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling):
            # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
            # The contract with custom optimizers is that their step() should accept an additional,
            # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
            # it can query its own state, invoke unscale_ on itself, etc.
            # The contract above is being deprecated to avoid introducing a `grad_scaler: GradScaler` argument
            # to `Optimizer.step`. The new behavior is going to add two Tensor attributes, `grad_scale`
            # and `found_inf`, to the passed optimizer so that the optimizer can utilize those
            # to skip the parameter updates or unscale gradients before updating parameters in
            # the fused kernel, e.g. `FusedAdamMathFunctor`.
            # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`,
            # while the method is expected to be called on the user's side, i.e. from their optimizers.
            kwargs_ = kwargs
            has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters
            if has_grad_scaler_kwarg:
                warnings.warn(
                    "GradScaler is going to stop passing itself as a keyword argument to the passed "
                    "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and "
                    "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.",
                    FutureWarning)
                kwargs_.update({"grad_scaler": self})
            else:
                if optimizer_state["stage"] is OptState.READY:
                    self._check_inf_per_device(optimizer)
                scaler = self._get_scale_async()
                found_inf = cast(
                    jt.Var,
                    sum([
                        t for t in optimizer_state["found_inf_per_device"].values()
                    ])
                )
                optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler
                optimizer.found_inf = found_inf
            retval = optimizer.step(*args, **kwargs_)
            optimizer_state["stage"] = OptState.STEPPED
            if not has_grad_scaler_kwarg:
                del optimizer.grad_scale
                del optimizer.found_inf
            return retval


        if optimizer_state["stage"] is OptState.READY:
            self.unscale_(optimizer)

        assert "found_inf_per_device" in optimizer_state, "No inf checks were recorded for this optimizer."

        retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)

        optimizer_state["stage"] = OptState.STEPPED

        return retval

    def update(self, new_scale=None):
        """
        Updates the scale factor.

        If any optimizer steps were skipped, the scale is multiplied by ``backoff_factor``
        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
        the scale is multiplied by ``growth_factor`` to increase it.

        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
        used directly; it's used to fill GradScaler's internal scale tensor. So if
        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
|
||||
affect the scale GradScaler uses internally.)
|
||||
|
||||
Args:
|
||||
new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.
|
||||
|
||||
.. warning::
|
||||
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
|
||||
been invoked for all optimizers used this iteration.
|
||||
"""
|
||||
if not self._enabled:
|
||||
return
|
||||
|
||||
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
|
||||
|
||||
if new_scale is not None:
|
||||
# Accept a new user-defined scale.
|
||||
if isinstance(new_scale, float):
|
||||
self._scale.fill_(new_scale) # type: ignore[union-attr]
|
||||
else:
|
||||
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
|
||||
assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined]
|
||||
assert new_scale.numel() == 1, reason
|
||||
assert new_scale.requires_grad is False, reason
|
||||
self._scale.copy_(new_scale) # type: ignore[union-attr]
|
||||
else:
|
||||
# Consume shared inf/nan data collected from optimizers to update the scale.
|
||||
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
|
||||
found_infs = [state["found_inf_per_device"]
|
||||
for state in self._per_optimizer_states.values()
|
||||
]
|
||||
|
||||
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
|
||||
|
||||
found_inf_combined = found_infs[0]
|
||||
if len(found_infs) > 1:
|
||||
for i in range(1, len(found_infs)):
|
||||
found_inf_combined += found_infs[i]
|
||||
|
||||
|
||||
current_scale = _scale
|
||||
if found_inf_combined:
|
||||
current_scale *=self._backoff_factor
|
||||
_growth_tracker = 0
|
||||
else:
|
||||
successful = _growth_tracker+1
|
||||
if successful == self._growth_interval:
|
||||
new_scale = current_scale*self._growth_factor
|
||||
if new_scale < 1e9:
|
||||
current_scale = new_scale
|
||||
_growth_tracker = 0
|
||||
else:
|
||||
_growth_tracker = successful
|
||||
|
||||
self._scale, self._growth_tracker = current_scale,_growth_tracker
|
||||
|
||||
# To prepare for next iteration, clear the data collected from optimizers this iteration.
|
||||
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
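
    # A worked sketch of the update schedule above (values assumed, matching
    # common AMP defaults rather than anything fixed in this file): with
    # scale=65536, backoff_factor=0.5, growth_factor=2.0, growth_interval=2000,
    # one overflowing iteration yields 65536 * 0.5 = 32768 and resets the
    # tracker, while the 2000th consecutive clean iteration yields
    # 65536 * 2.0 = 131072, accepted because it stays below the 1e9 cap.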

    def _get_scale_async(self):
        return self._scale

    def get_scale(self):
        """
        Returns a Python float containing the current scale, or 1.0 if scaling is disabled.

        .. warning::
            :meth:`get_scale` incurs a CPU-GPU sync.
        """
        if self._enabled:
            return self._init_scale if self._scale is None else self._get_scale_async()
        else:
            return 1.0

    def get_growth_factor(self):
        r"""
        Returns a Python float containing the scale growth factor.
        """
        return self._growth_factor

    def set_growth_factor(self, new_factor):
        r"""
        Args:
            new_factor (float): Value to use as the new scale growth factor.
        """
        self._growth_factor = new_factor

    def get_backoff_factor(self):
        r"""
        Returns a Python float containing the scale backoff factor.
        """
        return self._backoff_factor

    def set_backoff_factor(self, new_factor):
        r"""
        Args:
            new_factor (float): Value to use as the new scale backoff factor.
        """
        self._backoff_factor = new_factor

    def get_growth_interval(self):
        r"""
        Returns a Python int containing the growth interval.
        """
        return self._growth_interval

    def set_growth_interval(self, new_interval):
        r"""
        Args:
            new_interval (int): Value to use as the new growth interval.
        """
        self._growth_interval = new_interval

    def _get_growth_tracker(self):
        if self._enabled:
            return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item()
        else:
            return 0

    def is_enabled(self):
        r"""
        Returns a bool indicating whether this instance is enabled.
        """
        return self._enabled

    def state_dict(self):
        r"""
        Returns the state of the scaler as a :class:`dict`. It contains five entries:

        * ``"scale"`` - a Python float containing the current scale
        * ``"growth_factor"`` - a Python float containing the current growth factor
        * ``"backoff_factor"`` - a Python float containing the current backoff factor
        * ``"growth_interval"`` - a Python int containing the current growth interval
        * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.

        If this instance is not enabled, returns an empty dict.

        .. note::
            If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
            should be called after :meth:`update`.
        """
        return {"scale": self.get_scale(),
                "growth_factor": self._growth_factor,
                "backoff_factor": self._backoff_factor,
                "growth_interval": self._growth_interval,
                "_growth_tracker": self._get_growth_tracker()} if self._enabled else {}

    def load_state_dict(self, state_dict):
        r"""
        Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.

        Args:
            state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
        """
        if not self._enabled:
            return

        if len(state_dict) == 0:
            raise RuntimeError("The source state dict is empty, possibly because it was saved "
                               "from a disabled instance of GradScaler.")

        self._init_scale = state_dict["scale"]
        if self._scale is not None:
            self._scale.fill_(state_dict["scale"])
        self._growth_factor = state_dict["growth_factor"]
        self._backoff_factor = state_dict["backoff_factor"]
        self._growth_interval = state_dict["growth_interval"]
        self._init_growth_tracker = state_dict["_growth_tracker"]
        if self._growth_tracker is not None:
            self._growth_tracker.fill_(state_dict["_growth_tracker"])
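
    # Checkpointing sketch (hypothetical names, not part of the original file):
    #     ckpt = {"model": model.state_dict(), "scaler": scaler.state_dict()}
    #     ...
    #     scaler.load_state_dict(ckpt["scaler"])
    # As the note on state_dict() says, save only after update(), so the
    # per-optimizer bookkeeping for the current iteration has been cleared.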

    def __getstate__(self):
        state = self.__dict__.copy()
        if self._enabled:
            assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\
                "of an iteration, or at the end after scaler.update()."
            # Pickling _scale and _growth_tracker Tensors directly triggers
            # "warnings.warn("pickle support for Storage will be removed in 1.5..."
            # so instead, we set the unpickled instance up to reinitialize them lazily.
            state['_init_scale'] = self.get_scale()
            state['_init_growth_tracker'] = self._get_growth_tracker()
            state['_scale'] = None
            state['_growth_tracker'] = None
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)

    def _check_inf_per_device(self, optimizer):
        _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")

        dummy_inv_scale = 1.0
        found_inf = 0.0

        self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
            self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)

        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

    def _found_inf_per_device(self, optimizer):
        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

@ -1,12 +0,0 @@
import math


def _jit_set_profiling_mode(x): pass
def _jit_set_profiling_executor(x): pass
def _jit_override_can_fuse_on_cpu(x): pass
def _jit_override_can_fuse_on_gpu(x): pass


def script(func):
    return func


inf = math.inf
nan = math.nan

@ -1,281 +0,0 @@
import jtorch
from typing import List, Optional, Tuple, Iterable, Iterator, Mapping, Any, overload, TypeVar, Dict
from typing_extensions import Self
import jittor as jt
from jtorch import make_module, Tensor, ModuleMisc, wrapper
#from . import init
from jittor import Function
import operator
import warnings

for k,v in jt.nn.__dict__.items():
    if callable(v):
        globals()[k] = wrapper(v)

for k,v in jt.nn.__dict__.items():
    if isinstance(v, type) and issubclass(v, jt.Module):
        globals()[k] = make_module(v)

from collections import OrderedDict
from collections import abc as container_abcs


class Module(ModuleMisc, jt.Module):

    def __call__(self, *args, **kw):
        return self.execute(*args, **kw)

    def execute(self, *args, **kw):
        return self.forward(*args, **kw)

    def get_submodule(self, target: str):
        if target == "":
            return self

        atoms: List[str] = target.split(".")
        mod: jt.nn.Module = self

        for item in atoms:
            if not hasattr(mod, item):
                raise AttributeError(mod._get_name() + " has no "
                                     "attribute `" + item + "`")

            mod = getattr(mod, item)

            if not isinstance(mod, jt.nn.Module):
                raise AttributeError("`" + item + "` is not "
                                     "an nn.Module")
        return mod
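
    # Usage sketch (hypothetical attribute names, not part of the original
    # file): for a network with `self.block = Block()` where Block defines
    # `self.conv`, `net.get_submodule("block.conv")` walks the dotted path and
    # returns the nested module, raising AttributeError on any missing hop.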


def Parameter(x:Tensor, requires_grad:bool=True) -> Tensor:
    x = x.clone()
    x.requires_grad = requires_grad
    x.retains_grad = requires_grad
    return x
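
# Unlike torch.nn.Parameter, Parameter here is a plain function: it clones the
# input Var and flags it for gradient tracking. A minimal sketch (hypothetical
# values, not part of the original file):
#     w = Parameter(jt.randn(3, 3))                        # w.requires_grad is True
#     frozen = Parameter(jt.ones(3), requires_grad=False)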

def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False):
    return jt.nn.embedding(input, weight)

def dropout(x, p=0.5, training=False):
    return jt.nn.dropout(x, p, training)


class Flatten(Module):
    ''' Flattens the contiguous range of dimensions in a Var.

    :param start_dim: the first dimension to be flattened. Default: 1.
    :type start_dim: int

    :param end_dim: the last dimension to be flattened. Default: -1.
    :type end_dim: int
    '''
    def __init__(self, start_dim=1, end_dim=-1):
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, x) -> jt.Var:
        return x.flatten(self.start_dim, self.end_dim)
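
    # Shape sketch: with the defaults (start_dim=1, end_dim=-1) an input of
    # shape (N, C, H, W) flattens to (N, C*H*W); Flatten(0, 1) instead merges
    # the leading dimensions, e.g. (N, C, H, W) -> (N*C, H, W).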


class _IncompatibleKeys:
    def __init__(self, missing_keys, unexpected_keys):
        self.missing_keys = missing_keys
        self.unexpected_keys = unexpected_keys

_BatchNorm = None

#from . import utils
normalize = wrapper(jt.normalize)

T = TypeVar('T', bound=Module)


class ModuleDict(Module):
    _modules: Dict[str, Module]  # type: ignore[assignment]

    def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
        super().__init__()
        if modules is not None:
            self.update(modules)

    def __getitem__(self, key: str) -> Module:
        return self._modules[key]

    def __setitem__(self, key: str, module: Module) -> None:
        self.add_module(key, module)

    def __delitem__(self, key: str) -> None:
        del self._modules[key]

    def __len__(self) -> int:
        return len(self._modules)

    def __iter__(self) -> Iterator[str]:
        return iter(self._modules)

    def __contains__(self, key: str) -> bool:
        return key in self._modules

    def clear(self) -> None:
        """Remove all items from the ModuleDict."""
        self._modules.clear()

    def pop(self, key: str) -> Module:
        r"""Remove key from the ModuleDict and return its module.

        Args:
            key (str): key to pop from the ModuleDict
        """
        v = self[key]
        del self[key]
        return v

    def keys(self) -> Iterable[str]:
        r"""Return an iterable of the ModuleDict keys."""
        return self._modules.keys()

    def items(self) -> Iterable[Tuple[str, Module]]:
        r"""Return an iterable of the ModuleDict key/value pairs."""
        return self._modules.items()

    def values(self) -> Iterable[Module]:
        r"""Return an iterable of the ModuleDict values."""
        return self._modules.values()

    def update(self, modules: Mapping[str, Module]) -> None:
        r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys.

        .. note::
            If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.

        Args:
            modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
                or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError("ModuleDict.update should be called with an "
                            "iterable of key/value pairs, but got " +
                            type(modules).__name__)

        if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
            for key, module in modules.items():
                self[key] = module
        else:
            # modules here can be a list with two items
            for j, m in enumerate(modules):
                if not isinstance(m, container_abcs.Iterable):
                    raise TypeError("ModuleDict update sequence element "
                                    "#" + str(j) + " should be Iterable; is " +
                                    type(m).__name__)
                if not len(m) == 2:
                    raise ValueError("ModuleDict update sequence element "
                                     "#" + str(j) + " has length " + str(len(m)) +
                                     "; 2 is required")
                # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
                # that's too cumbersome to type correctly with overloads, so we add an ignore here
                self[m[0]] = m[1]  # type: ignore[assignment]

    # remove forward altogether to fall back on Module's _forward_unimplemented
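
    # Usage sketch (hypothetical keys, not part of the original file):
    #     heads = ModuleDict({"cls": Linear(64, 10), "reg": Linear(64, 1)})
    #     out = heads["cls"](features)
    # update() accepts both mappings and sequences of (key, module) pairs.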


class ParameterList(Module):

    def __init__(self, values: Optional[Iterable[Any]] = None) -> None:
        super().__init__()
        self._size = 0
        if values is not None:
            self += values

    def _get_abs_string_index(self, idx):
        """Get the absolute index for the list of modules."""
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError(f'index {idx} is out of range')
        if idx < 0:
            idx += len(self)
        return str(idx)

    @overload
    def __getitem__(self, idx: int) -> Any:
        ...

    @overload
    def __getitem__(self: T, idx: slice) -> T:
        ...

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            out = self.__class__()
            for i in range(start, stop, step):
                out.append(self[i])
            return out
        else:
            idx = self._get_abs_string_index(idx)
            return getattr(self, idx)

    def __setitem__(self, idx: int, param: Any) -> None:
        # Note that all other functions that add an entry to the list part of
        # the ParameterList end up here. So this is the only place where we need
        # to wrap things into Parameter if needed.
        # Objects added via setattr() are not in the list part and thus won't
        # call into this function.
        idx = self._get_abs_string_index(idx)
        # `Parameter` is a plain function in this port, so it cannot be used
        # with isinstance(); wrap any Var that is not yet tracking gradients.
        if isinstance(param, jt.Var) and not param.requires_grad:
            param = Parameter(param)
        return setattr(self, idx, param)

    def __len__(self) -> int:
        return self._size

    def __iter__(self) -> Iterator[Any]:
        return iter(self[i] for i in range(len(self)))

    def __iadd__(self, parameters: Iterable[Any]) -> Self:
        return self.extend(parameters)

    def __dir__(self):
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def append(self, value: Any) -> 'ParameterList':
        """Append a given value at the end of the list.

        Args:
            value (Any): value to append
        """
        new_idx = len(self)
        self._size += 1
        self[new_idx] = value
        return self

    def extend(self, values: Iterable[Any]) -> Self:
        """Append values from a Python iterable to the end of the list.

        Args:
            values (iterable): iterable of values to append
        """
        # Tensor is an iterable but we never want to unpack it here
        if not isinstance(values, container_abcs.Iterable) or isinstance(values, jt.Var):
            raise TypeError("ParameterList.extend should be called with an "
                            "iterable, but got " + type(values).__name__)
        for value in values:
            self.append(value)
        return self

    def extra_repr(self) -> str:
        child_lines = []
        for k, p in enumerate(self):
            if isinstance(p, jt.Var):
                size_str = 'x'.join(str(size) for size in p.size())
                # `Parameter` is a function in this port, so distinguish
                # parameters from plain tensors by the requires_grad flag.
                parastr = '{} containing: [{} of size {} on {}]'.format(
                    "Parameter" if p.requires_grad else "Tensor",
                    p.dtype, size_str, "cuda" if jt.flags.use_cuda else "cpu")
                child_lines.append('  (' + str(k) + '): ' + parastr)
            else:
                child_lines.append('  (' + str(k) + '): Object of type: ' + type(p).__name__)

        tmpstr = '\n'.join(child_lines)
        return tmpstr

    def __call__(self, *args, **kwargs):
        raise RuntimeError('ParameterList should not be called.')
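
# Minimal round-trip sketch for ParameterList (hypothetical shapes, not part
# of the original file); run this module directly to exercise it.
if __name__ == "__main__":
    plist = ParameterList([Parameter(jt.ones(2, 2)), Parameter(jt.zeros(3))])
    plist.append(Parameter(jt.randn(4)))
    print(len(plist))  # 3
    for p in plist:
        print(p.shape, p.requires_grad)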

@ -1,16 +0,0 @@
import jittor as jt

for k,v in jt.nn.init.__dict__.items():
    if callable(v):
        globals()[k] = v


normal = gauss
normal_ = gauss_
xavier_normal = xavier_gauss
xavier_normal_ = xavier_gauss_
zeros_ = zero_


jt.Var.normal_ = normal_

@ -1 +0,0 @@
from . import rnn

@ -1,20 +0,0 @@
import jittor as jt

PackedSequence = None

def pad_sequence(sequences, batch_first=False, padding_value=0.0):
    # Pad every sequence up to the length of the longest one.
    max_f = max([len(s) for s in sequences])
    b = len(sequences)
    if batch_first:
        ret = sequences[0].new_full([b, max_f] + list(sequences[0].shape[1:]), padding_value)
        for i, s in enumerate(sequences):
            ret[i, :len(s)] = s
    else:
        ret = sequences[0].new_full([max_f, b] + list(sequences[0].shape[1:]), padding_value)
        for i, s in enumerate(sequences):
            ret[:len(s), i] = s
    return ret
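
# Shape sketch for pad_sequence (hypothetical inputs, not part of the original
# file); run this module directly to see the padded layouts.
if __name__ == "__main__":
    seqs = [jt.zeros((3, 5)), jt.zeros((1, 5))]
    print(pad_sequence(seqs, batch_first=True).shape)  # batch-major: [2, 3, 5]
    print(pad_sequence(seqs).shape)                    # time-major:  [3, 2, 5]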

File diff suppressed because it is too large

@ -1,102 +0,0 @@
#include "pyjt/py_obj_holder.h"
#include "utils/str_utils.h"
#include "jtorch_core.h"
#include "graph.h"
#include "grad.h"
#include "ops/op_register.h"

namespace jittor {

void pyjt_def_all(PyObject* m);

EXTERN_LIB void setter_use_cuda(int value);

Device::Device(const string& name, int ordinal) : name(name) {
    if (startswith(name, "cpu"))
        setter_use_cuda(0);
    else
        setter_use_cuda(1);
}

unordered_map<int64, VarPtr> grad_backup;
EXTERN_LIB void (*_var_free_hook)(Var*);
EXTERN_LIB unordered_map<int64, VarPtr>* _grad_backup_ptr;

void jtorch_var_free_hook(Var* v) {
    auto iter = grad_backup.find(v->id);
    if (iter != grad_backup.end()) {
        grad_backup.erase(iter);
    }
}

void jtorch_init() {
    _var_free_hook = &jtorch_var_free_hook;
    _grad_backup_ptr = &grad_backup;
}

inline static VarPtr& get_grad(Var* v) {
    return grad_backup[v->id];
}

static auto make_binary = get_op_info("binary")
    .get_constructor<VarPtr, Var*, Var*, NanoString>();

inline static void add_grad(VarPtr& a, VarPtr&& b) {
    if (!a) a = move(b);
    else {
        a = make_binary(a, b, ns_add);
    }
}

void grad_set(VarHolder* x, Maybe<VarHolder> v) {
    if (!v) {
        grad_del(x);
        return;
    }
    grad_backup[x->var->id] = v.ptr->var;
}

Maybe<VarHolder> grad_get(VarHolder* x) {
    auto iter = grad_backup.find(x->var->id);
    if (iter != grad_backup.end()) {
        if (!iter->second.ptr) return nullptr;
        return new VarHolder(iter->second.ptr);
    }
    return nullptr;
}

void grad_del(VarHolder* x) {
    auto iter = grad_backup.find(x->var->id);
    if (iter != grad_backup.end())
        grad_backup.erase(iter);
}

void backward(VarHolder* x) {
    vector<Node*> gnodes({x->var});
    bfs_backward(gnodes, [&](Node* node) {
        if (node->is_stop_grad())
            return false;
        return true;
    });
    vector<Var*> targets;
    for (auto* node : gnodes) {
        if (node->is_var() && node->flags.get(NodeFlags::_th_require_grad))
            targets.push_back(node->var());
    }
    auto grads = grad(x->var, targets);
    for (int i=0; i<targets.size(); i++) {
        auto& gptr = get_grad(targets[i]);
        add_grad(gptr, move(grads[i]));
    }
}

}

static void init_module(PyModuleDef* mdef, PyObject* m) {
    jittor::jtorch_init();
    mdef->m_doc = "Inner c++ core of jtorch";
    jittor::pyjt_def_all(m);
}
PYJT_MODULE_INIT(jtorch_core);

@ -1,40 +0,0 @@
#pragma once
#include "common.h"
#include "var_holder.h"
#include "misc/fast_shared_ptr.h"

namespace jittor {

// @pyjt(device)
// @attrs(heaptype)
struct Device {
    string name;

    // @pyjt(__init__)
    Device(const string& name, int ordinal=0);
    // @pyjt(__get__type, __str__)
    inline string get_type() {return name;}
    // @pyjt(__get__index)
    inline int index() {return 0;}
};

// @pyjt(backward)
void backward(VarHolder* x);

// @pyjt(grad_set)
void grad_set(VarHolder* x, Maybe<VarHolder> v);
// @pyjt(grad_get)
Maybe<VarHolder> grad_get(VarHolder* x);
// @pyjt(grad_del)
void grad_del(VarHolder* x);

// @pyjt(retain_grad_set)
inline void retain_grad_set(VarHolder* x, bool v) {
    x->var->flags.set(NodeFlags::_th_require_grad, v);
}
// @pyjt(retain_grad_get)
inline bool retain_grad_get(VarHolder* x) {
    return x->var->flags.get(NodeFlags::_th_require_grad);
}

}

@ -1,25 +0,0 @@
import unittest
import numpy as np
import torch
import jittor as jt

class TestConflictFunc(unittest.TestCase):
    def test_max(self):
        a = torch.Tensor([1,4,2])
        assert a.max() == 4
        v, k = a.max(dim=0)
        assert v==4 and k==1

    def test_argsort(self):
        a = torch.Tensor([1,4,2])
        k = a.argsort()
        assert jt.all_equal(k, [0,2,1])

        with jt.flag_scope(th_mode=0):
            k, v = a.argsort()
            assert jt.all_equal(k, [0,2,1])


if __name__ == "__main__":
    unittest.main()

@ -1,58 +0,0 @@
import unittest
import numpy as np
import torch

class TestFunction(unittest.TestCase):
    def test_example1(self):
        import jtorch
        from jtorch import Function

        class MyFunc(Function):
            @staticmethod
            def forward(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y

            @staticmethod
            def backward(self, grad0, grad1):
                return grad0 * self.y, grad1 * self.x

        a = jtorch.array(3.0)
        a.requires_grad = True
        b = jtorch.array(4.0)
        b.requires_grad = True
        func = MyFunc.apply
        c,d = func(a, b)
        (c+d*3).backward()
        assert a.grad.data == 4
        assert b.grad.data == 9

    def test_example2(self):
        import jtorch as jt
        from jtorch import Function

        class MyFunc(Function):
            @staticmethod
            def forward(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y

            @staticmethod
            def backward(self, grad0, grad1):
                assert grad1 is None
                return grad0 * self.y, None

        a = jt.array(3.0)
        a.requires_grad = True
        b = jt.array(4.0)
        b.requires_grad = True
        func = MyFunc.apply
        c,d = func(a, b)
        d.stop_grad()
        da, db = jt.grad(c+d*3, [a, b])
        assert da.data == 4
        assert db.data == 0

if __name__ == "__main__":
    unittest.main()

@ -1,24 +0,0 @@
import unittest
import numpy as np
import torch

class TestMisc(unittest.TestCase):
    def test_update_grad(self):
        class Net(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.Parameter(torch.Tensor([1.0, 2.0]))
        net = Net()
        assert(net.a.requires_grad)
        net.load_state_dict({"a": torch.Tensor([3.0, 4.0])})
        assert(net.a.requires_grad)

    def test_reshape(self):
        a = torch.ones(3,3)
        a.requires_grad = True
        b = torch.reshape(a, [9])
        assert b.requires_grad == True


if __name__ == "__main__":
    unittest.main()

@ -1,56 +0,0 @@
import unittest
import numpy as np
import os
import subprocess as sp
import sys

def check_two(cmd, parser=None, checker=None):
    jtorch_out = sp.getoutput(cmd)
    print("=========JTORCH OUT==========")
    print(jtorch_out)
    torch_out = sp.getoutput("PYTHONPATH= "+cmd)
    print("=========TORCH OUT==========")
    print(torch_out)
    if parser:
        torch_out = parser(torch_out)
        jtorch_out = parser(jtorch_out)
    if checker:
        checker(torch_out, jtorch_out)
    else:
        assert torch_out == jtorch_out
    return jtorch_out, torch_out

jtorch_path = os.path.join(os.path.dirname(__file__), "..")
# comes from https://pytorch.org/tutorials/beginner/pytorch_with_examples.html
class TestTutorial(unittest.TestCase):
    def test_auto_grad1(self):
        check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad1.py",
            parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
            checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
    def test_auto_grad2(self):
        check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad2.py",
            parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
            checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
    def test_auto_grad3(self):
        check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad3.py",
            parser=lambda s: np.array(s.split())[[-9,-7,-4,-2]].astype(float),
            checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
    def test_auto_grad4(self):
        check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad4.py",
            parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
            checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
    def test_auto_grad5(self):
        check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad5_optim.py",
            parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
            checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2))
    def test_auto_grad6(self):
        check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad6_module.py",
            parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
            checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
    def test_auto_grad7(self):
        check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad7_dynet.py",
            parser=lambda s: np.array(s.split())[[-13,-10,-7,-3]].astype(float),
            checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2))

if __name__ == "__main__":
    unittest.main()

@ -1,44 +0,0 @@
import torch
import math

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# Create random input and output data
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Randomly initialize weights
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(20000):
    # Forward pass: compute predicted y
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    if t % 1000 == 999:
        print(t, loss)

    # Backprop to compute gradients of a, b, c, d with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x ** 2).sum()
    grad_d = (grad_y_pred * x ** 3).sum()

    # Update weights using gradient descent
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d
    # print(t, torch.liveness_info())
    # torch.sync_all()


print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

@ -1,60 +0,0 @@
# -*- coding: utf-8 -*-
import torch
import math

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Create random Tensors for weights. For a third order polynomial, we need
# 4 weights: y = a + b x + c x^2 + d x^3
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.randn((), device=device, dtype=dtype, requires_grad=True)
b = torch.randn((), device=device, dtype=dtype, requires_grad=True)
c = torch.randn((), device=device, dtype=dtype, requires_grad=True)
d = torch.randn((), device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(20000):
    # Forward pass: compute predicted y using operations on Tensors.
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # Compute and print loss using operations on Tensors.
    # Now loss is a Tensor of shape (1,)
    # loss.item() gets the scalar value held in the loss.
    loss = (y_pred - y).pow(2).sum()
    if t % 1000 == 999:
        print(t, loss.item())

    # Use autograd to compute the backward pass. This call will compute the
    # gradient of loss with respect to all Tensors with requires_grad=True.
    # After this call a.grad, b.grad, c.grad and d.grad will be Tensors holding
    # the gradient of the loss with respect to a, b, c, d respectively.
    loss.backward()

    # Manually update weights using gradient descent. Wrap in torch.no_grad()
    # because weights have requires_grad=True, but we don't need to track this
    # in autograd.
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # Manually zero the gradients after updating weights
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

@ -1,85 +0,0 @@
# -*- coding: utf-8 -*-
import torch
import math


class LegendrePolynomial3(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return 0.5 * (5 * input ** 3 - 3 * input)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        return grad_output * 1.5 * (5 * input ** 2 - 1)


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Create random Tensors for weights. For this example, we need
# 4 weights: y = a + b * P3(c + d * x), these weights need to be initialized
# not too far from the correct result to ensure convergence.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)

learning_rate = 5e-6
for t in range(2000):
    # To apply our Function, we use the Function.apply method. We alias this as 'P3'.
    P3 = LegendrePolynomial3.apply

    # Forward pass: compute predicted y using operations; we compute
    # P3 using our custom autograd operation.
    y_pred = a + b * P3(c + d * x)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # Manually zero the gradients after updating weights
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

print(f'Result: y = {a.item()} + {b.item()} * P3( {c.item()} + {d.item()} x)')

@ -1,71 +0,0 @@
# -*- coding: utf-8 -*-
import torch
import math


# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# For this example, the output y is a linear function of (x, x^2, x^3), so
# we can consider it as a linear layer neural network. Let's prepare the
# tensor (x, x^2, x^3).
p = torch.tensor([1, 2, 3])
xx = x.unsqueeze(-1).pow(p)

# In the above code, x.unsqueeze(-1) has shape (2000, 1), and p has shape
# (3,), for this case, broadcasting semantics will apply to obtain a tensor
# of shape (2000, 3)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. The Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
# The Flatten layer flattens the output of the linear layer to a 1D tensor,
# to match the shape of `y`.
model = torch.nn.Sequential(
    torch.nn.Linear(3, 1),
    torch.nn.Flatten(0, 1)
)

# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-6
for t in range(8000):

    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(xx)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the
    # loss.
    loss = loss_fn(y_pred, y)
    if t % 1000 == 999:
        print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Tensor, so
    # we can access its gradients like we did before.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

# You can access the first layer of `model` like accessing the first item of a list
linear_layer = model[0]

# For the linear layer, its parameters are stored as `weight` and `bias`.
print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3')

@ -1,53 +0,0 @@
# -*- coding: utf-8 -*-
import torch
import math


# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Prepare the input tensor (x, x^2, x^3).
p = torch.tensor([1, 2, 3])
xx = x.unsqueeze(-1).pow(p)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(3, 1),
    torch.nn.Flatten(0, 1)
)
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use RMSprop; the optim package contains many other
# optimization algorithms. The first argument to the RMSprop constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-3
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
for t in range(8000):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(xx)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    if t % 1000 == 999:
        print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because by default, gradients are
    # accumulated in buffers (i.e., not overwritten) whenever .backward()
    # is called. Check out the docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()


linear_layer = model[0]
print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3')

@ -1,59 +0,0 @@
# -*- coding: utf-8 -*-
import torch
import math


class Polynomial3(torch.nn.Module):
    def __init__(self):
        """
        In the constructor we instantiate four parameters and assign them as
        member parameters.
        """
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(()))
        self.b = torch.nn.Parameter(torch.randn(()))
        self.c = torch.nn.Parameter(torch.randn(()))
        self.d = torch.nn.Parameter(torch.randn(()))

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3

    def string(self):
        """
        Just like any class in Python, you can also define custom methods on PyTorch modules
        """
        return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3'


# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Construct our model by instantiating the class defined above
model = Polynomial3()

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters (defined
# with torch.nn.Parameter) which are members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)
for t in range(8000):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    if t % 1000 == 999:
        print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f'Result: {model.string()}')

@ -1,69 +0,0 @@
# -*- coding: utf-8 -*-
import random
import torch
import math


class DynamicNet(torch.nn.Module):
    def __init__(self):
        """
        In the constructor we instantiate five parameters and assign them as members.
        """
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(()))
        self.b = torch.nn.Parameter(torch.randn(()))
        self.c = torch.nn.Parameter(torch.randn(()))
        self.d = torch.nn.Parameter(torch.randn(()))
        self.e = torch.nn.Parameter(torch.randn(()))

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose either 4 or 5
        and reuse the e parameter to compute the contribution of these orders.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same parameter many
        times when defining a computational graph.
        """
        y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3
        for exp in range(4, random.randint(4, 6)):
            y = y + self.e * x ** exp
        return y

    def string(self):
        """
        Just like any class in Python, you can also define custom methods on PyTorch modules
        """
        return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?'


# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Construct our model by instantiating the class defined above
model = DynamicNet()

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9)
for t in range(60000):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    if t % 2000 == 1999:
        print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f'Result: {model.string()}')

@ -1,106 +0,0 @@
import torch
from torch import nn
# from jtorch.utils import DataLoader
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

print(len(train_dataloader))
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)


loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


epochs = 5
test(test_dataloader, model, loss_fn)
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

@ -1,5 +0,0 @@
cpp_extension = None
_flatten_dense_tensors = None
_unflatten_dense_tensors = None

tensorboard = None

@ -1,3 +0,0 @@
# TODO: Implement this
_register_pytree_node = None
_dict_flatten = None

@ -1,8 +0,0 @@
detach_variable = None


def checkpoint(
    *args,
    **kwargs
):
    pass
@ -1,137 +0,0 @@
import jittor as jt
import jittor.dataset
from jittor.dataset import Dataset as JDataset

from collections import namedtuple
from typing import Any, Callable, Iterable, Optional, Sequence, Union


class Dataset:
    def __getitem__(self, index):
        raise NotImplementedError

class IterableDataset:
    def __iter__(self):
        raise NotImplementedError


class DataLoader(JDataset):
    def __init__(self, dataset,
                 batch_size: Optional[int] = 1,
                 shuffle: Optional[bool] = False,
                 sampler = None,
                 batch_sampler = None,
                 num_workers: int = 0,
                 collate_fn = None,
                 pin_memory: bool = False,
                 drop_last: bool = False,
                 timeout: float = 0,
                 worker_init_fn = None,
                 multiprocessing_context=None,
                 generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False,
                 pin_memory_device: str = "") -> None:
        super().__init__(batch_size=batch_size,
                         shuffle=shuffle,
                         num_workers=num_workers,
                         drop_last=drop_last)

        unsupported_kwargs = {
            "batch_sampler": batch_sampler,
            "pin_memory": pin_memory,
            "timeout": timeout,
            "worker_init_fn": worker_init_fn,
            "multiprocessing_context": multiprocessing_context,
            "generator": generator,
            "persistent_workers": persistent_workers,
            "pin_memory_device": pin_memory_device
        }
        for kwarg, value in unsupported_kwargs.items():
            if value:
                jt.LOG.w(f"Not implemented Dataloader kwarg: {kwarg}")

        self.dataset = dataset
        self.collate_fn = collate_fn
        self.sampler = sampler

        if not isinstance(dataset, IterableDataset):
            self.total_len = len(dataset)
        else:
            # TODO: support multiple workers for iterable datasets
            assert num_workers == 0

    def collate_batch(self, batch):
        if self.collate_fn is not None:
            return self.collate_fn(batch)
        else:
            return super().collate_batch(batch)

    def __getitem__(self, i):
        return self.dataset[i]

    def __iter__(self):
        if isinstance(self.dataset, IterableDataset):
            return self.inner_iter()
        else:
            return super().__iter__()

    def inner_iter(self):
        current_batch = []

        if jt.world_size > 1:
            assert self.batch_size % jt.world_size == 0, \
                f"IterableDataset does not support a batch size ({self.batch_size}) that is not evenly divisible by the number of processes {jt.world_size}"
            real_batch_size = int(self.batch_size / jt.world_size)
        else:
            real_batch_size = self.batch_size

        for element in self.dataset:
            current_batch.append(element)

            if len(current_batch) == real_batch_size:
                current_batch = self.collate_batch(current_batch)
                current_batch = self.to_jittor(current_batch)
                yield current_batch
                current_batch = []

        if not self.drop_last and len(current_batch) > 0:
            current_batch = self.collate_batch(current_batch)
            yield self.to_jittor(current_batch)
|
||||
|
||||
# def get_worker_info():
|
||||
# # always return the fake worker info
|
||||
# return namedtuple('WorkerInfo', 'id num_workers')(0, 1)
|
||||
|
||||
# class RandomSampler(jt.dataset.RandomSampler):
|
||||
# def __init__(self, dataset, generator=None, **kwargs):
|
||||
# super().__init__(dataset, **kwargs)
|
||||
|
||||
# def __iter__(self):
|
||||
# if getattr(self.dataset, "support_random_access", True):
|
||||
# return super().__iter__()
|
||||
# else:
|
||||
# self.dataset.shuffle()
|
||||
# return iter(range(self.dataset.__real_len__() if hasattr(self.dataset,"__real_len__") else self.dataset.__len__()))
|
||||
|
||||
# class DistributedSampler(jt.dataset.Sampler):
|
||||
# def __init__(self, sampler: RandomSampler):
|
||||
# assert(isinstance(sampler, RandomSampler))
|
||||
# self.sampler = sampler
|
||||
|
||||
# def set_epoch(self, epoch: int):
|
||||
# ### do nothing, let jittor's inner dataset handle
|
||||
# pass
|
||||
|
||||
# def __iter__(self):
|
||||
# return self.sampler.__iter__()
|
||||
|
||||
# def __len__(self):
|
||||
# return self.sampler.__len__()
|
||||
|
||||
# BatchSampler = jt.dataset.BatchSampler
|
||||
# Sampler = jt.dataset.Sampler
|
||||
# SequentialSampler = jt.dataset.SequentialSampler
|
||||
# SubsetRandomSampler = jt.dataset.SubsetRandomSampler
|
||||
|
||||
# TensorDataset = Dataset
|
|
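For orientation, a minimal hedged usage sketch of the wrapper above; MyDataset and its (input, target) pairs are illustrative, not part of the file. Map-style datasets defer batching to jittor's Dataset, while an IterableDataset goes through inner_iter.

# Hypothetical usage of the DataLoader above.
class MyDataset(Dataset):
    def __init__(self, n=8):
        self.n = n
    def __len__(self):
        return self.n
    def __getitem__(self, index):
        return index, index * 2  # (input, target) pair

loader = DataLoader(MyDataset(), batch_size=4, shuffle=True)
for batch in loader:
    pass  # each batch arrives collated and converted via to_jittor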
@ -1,9 +0,0 @@
from typing import Callable, Union
Dtype = Union[Callable, str]

def get_string_dtype(dtype):
    if callable(dtype):
        dtype = dtype.__name__
    if not isinstance(dtype, str):
        raise ValueError(f"dtype is expected to be str, python type function, or jittor type function, but got {dtype}.")
    return dtype
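A quick hedged illustration of the coercion rule: callables collapse to their __name__, strings pass through, anything else raises. Whether a given jittor dtype alias exposes the expected __name__ is an assumption here.

# Illustrative calls, not part of the file.
assert get_string_dtype("float32") == "float32"  # strings pass through
assert get_string_dtype(int) == "int"            # python type -> its name
# Assumed: jt.float32 is a callable whose __name__ is "float32", i.e.
# import jittor as jt; get_string_dtype(jt.float32) == "float32"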
@ -1,34 +0,0 @@
import os
import glob
import shutil
import sys

home_path = os.path.join(os.path.dirname(__file__), "..", "..", "..")
home_path = os.path.abspath(home_path)

def callback(func, path, exc_info):
    print(f"remove \"{path}\" failed.")

def rmtree(path):
    if os.path.isdir(path):
        print(f"remove \"{path}\" recursively.")
        shutil.rmtree(path, onerror=callback)

def remove_tmpfile():
    dist_file = home_path + "/dist"
    egg_file = glob.glob(home_path + "/**/*egg-info")
    rmtree(dist_file)
    for e in egg_file:
        rmtree(e)

def run_cmd(cmd):
    print("[CMD]", cmd)
    assert os.system(cmd) == 0

os.chdir(home_path)
remove_tmpfile()

run_cmd(f"{sys.executable} ./setup.py sdist")
run_cmd(f"{sys.executable} -m twine upload dist/*")

remove_tmpfile()
@ -1,46 +0,0 @@
import importlib.machinery
import os


def _download_file_from_remote_location(fpath: str, url: str) -> None:
    pass


def _is_remote_location_available() -> bool:
    return False


def _get_extension_path(lib_name):

    lib_dir = os.path.dirname(__file__)
    if os.name == "nt":
        # Register the main torchvision library location on the default DLL path
        import ctypes
        import sys

        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
        with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
        prev_error_mode = kernel32.SetErrorMode(0x0001)

        if with_load_library_flags:
            kernel32.AddDllDirectory.restype = ctypes.c_void_p

        if sys.version_info >= (3, 8):
            os.add_dll_directory(lib_dir)
        elif with_load_library_flags:
            res = kernel32.AddDllDirectory(lib_dir)
            if res is None:
                err = ctypes.WinError(ctypes.get_last_error())
                err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
                raise err

        kernel32.SetErrorMode(prev_error_mode)

    loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)

    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
    ext_specs = extfinder.find_spec(lib_name)
    if ext_specs is None:
        raise ImportError

    return ext_specs.origin
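For reference, a hedged sketch of a typical call site for _get_extension_path; the extension name "_C" is illustrative, not part of the file.

# Hypothetical call site.
try:
    lib_path = _get_extension_path("_C")
    print(f"compiled extension found at {lib_path}")
except ImportError:
    print("no compiled extension is built for this package")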
@ -1,9 +0,0 @@
from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST

__all__ = (
    "EMNIST",
    "FashionMNIST",
    "QMNIST",
    "MNIST",
    "KMNIST",
)
@ -1,558 +0,0 @@
import codecs
import os
import os.path
import shutil
import string
import sys
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.error import URLError

import numpy as np
import torch
from PIL import Image

from .utils import check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
from .vision import VisionDataset


class MNIST(VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
            and ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    mirrors = [
        "http://yann.lecun.com/exdb/mnist/",
        "https://ossci-datasets.s3.amazonaws.com/mnist/",
    ]

    resources = [
        ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
    ]

    training_file = "training.pt"
    test_file = "test.pt"
    classes = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]

    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.train = train  # training set or test set

        if self._check_legacy_exist():
            self.data, self.targets = self._load_legacy_data()
            return

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        self.data, self.targets = self._load_data()

    def _check_legacy_exist(self):
        processed_folder_exists = os.path.exists(self.processed_folder)
        if not processed_folder_exists:
            return False

        return all(
            check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
        )

    def _load_legacy_data(self):
        # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
        # directly.
        data_file = self.training_file if self.train else self.test_file
        return torch.load(os.path.join(self.processed_folder, data_file))

    def _load_data(self):
        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
        data = read_image_file(os.path.join(self.raw_folder, image_file))

        label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
        targets = read_label_file(os.path.join(self.raw_folder, label_file))

        return data, targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode="L")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def raw_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, "raw")

    @property
    def processed_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, "processed")

    @property
    def class_to_idx(self) -> Dict[str, int]:
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self) -> bool:
        return all(
            check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
            for url, _ in self.resources
        )

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist already."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        # download files
        for filename, md5 in self.resources:
            for mirror in self.mirrors:
                url = f"{mirror}{filename}"
                try:
                    print(f"Downloading {url}")
                    download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
                except URLError as error:
                    print(f"Failed to download (trying next):\n{error}")
                    continue
                finally:
                    print()
                break
            else:
                raise RuntimeError(f"Error downloading {filename}")

    def extra_repr(self) -> str:
        split = "Train" if self.train is True else "Test"
        return f"Split: {split}"


class FashionMNIST(MNIST):
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte``
            and ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]

    resources = [
        ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
        ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"),
        ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"),
        ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"),
    ]
    classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]


class KMNIST(MNIST):
    """`Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte``
            and ``KMNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]

    resources = [
        ("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
        ("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
        ("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
        ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"),
    ]
    classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"]


class EMNIST(MNIST):
    """`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``EMNIST/raw/train-images-idx3-ubyte``
            and ``EMNIST/raw/t10k-images-idx3-ubyte`` exist.
        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
            which one to use.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    url = "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip"
    md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
    splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist")
    # Merged classes assume the same structure for both uppercase and lowercase versions
    _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"}
    _all_classes = set(string.digits + string.ascii_letters)
    classes_split_dict = {
        "byclass": sorted(list(_all_classes)),
        "bymerge": sorted(list(_all_classes - _merged_classes)),
        "balanced": sorted(list(_all_classes - _merged_classes)),
        "letters": ["N/A"] + list(string.ascii_lowercase),
        "digits": list(string.digits),
        "mnist": list(string.digits),
    }

    def __init__(self, root: str, split: str, **kwargs: Any) -> None:
        self.split = verify_str_arg(split, "split", self.splits)
        self.training_file = self._training_file(split)
        self.test_file = self._test_file(split)
        super().__init__(root, **kwargs)
        self.classes = self.classes_split_dict[self.split]

    @staticmethod
    def _training_file(split) -> str:
        return f"training_{split}.pt"

    @staticmethod
    def _test_file(split) -> str:
        return f"test_{split}.pt"

    @property
    def _file_prefix(self) -> str:
        return f"emnist-{self.split}-{'train' if self.train else 'test'}"

    @property
    def images_file(self) -> str:
        return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")

    @property
    def labels_file(self) -> str:
        return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")

    def _load_data(self):
        return read_image_file(self.images_file), read_label_file(self.labels_file)

    def _check_exists(self) -> bool:
        return all(check_integrity(file) for file in (self.images_file, self.labels_file))

    def download(self) -> None:
        """Download the EMNIST data if it doesn't exist already."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
        gzip_folder = os.path.join(self.raw_folder, "gzip")
        for gzip_file in os.listdir(gzip_folder):
            if gzip_file.endswith(".gz"):
                extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
        shutil.rmtree(gzip_folder)


class QMNIST(MNIST):
    """`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset whose ``raw``
            subdir contains binary files of the datasets.
        what (string, optional): Can be 'train', 'test', 'test10k',
            'test50k', or 'nist' for respectively the mnist compatible
            training set, the 60k qmnist testing set, the 10k qmnist
            examples that match the mnist testing set, the 50k
            remaining qmnist testing examples, or all the nist
            digits. The default is to select 'train' or 'test'
            according to the compatibility argument 'train'.
        compat (bool, optional): A boolean that says whether the target
            for each example is class number (for compatibility with
            the MNIST dataloader) or a torch vector containing the
            full qmnist information. Default=True.
        download (bool, optional): If True, downloads the dataset from
            the internet and puts it in root directory. If dataset is
            already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that
            takes in a PIL image and returns a transformed
            version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform
            that takes in the target and transforms it.
        train (bool, optional, compatibility): When argument 'what' is
            not specified, this boolean decides whether to load the
            training set or the testing set. Default: True.
    """

    subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"}
    resources: Dict[str, List[Tuple[str, str]]] = {  # type: ignore[assignment]
        "train": [
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz",
                "ed72d4157d28c017586c42bc6afe6370",
            ),
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz",
                "0058f8dd561b90ffdd0f734c6a30e5e4",
            ),
        ],
        "test": [
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz",
                "1394631089c404de565df7b7aeaf9412",
            ),
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz",
                "5b5b05890a5e13444e108efe57b788aa",
            ),
        ],
        "nist": [
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz",
                "7f124b3b8ab81486c9d8c2749c17f834",
            ),
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz",
                "5ed0e788978e45d4a8bd4b7caec3d79d",
            ),
        ],
    }
    classes = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]

    def __init__(
        self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any
    ) -> None:
        if what is None:
            what = "train" if train else "test"
        self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
        self.compat = compat
        self.data_file = what + ".pt"
        self.training_file = self.data_file
        self.test_file = self.data_file
        super().__init__(root, train, **kwargs)

    @property
    def images_file(self) -> str:
        (url, _), _ = self.resources[self.subsets[self.what]]
        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])

    @property
    def labels_file(self) -> str:
        _, (url, _) = self.resources[self.subsets[self.what]]
        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])

    def _check_exists(self) -> bool:
        return all(check_integrity(file) for file in (self.images_file, self.labels_file))

    def _load_data(self):
        data = read_sn3_pascalvincent_tensor(self.images_file)
        if data.dtype != torch.uint8:
            raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}")
        if data.ndimension() != 3:
            raise ValueError(f"data should have 3 dimensions instead of {data.ndimension()}")

        targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
        if targets.ndimension() != 2:
            raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}")

        if self.what == "test10k":
            data = data[0:10000, :, :].clone()
            targets = targets[0:10000, :].clone()
        elif self.what == "test50k":
            data = data[10000:, :, :].clone()
            targets = targets[10000:, :].clone()

        return data, targets

    def download(self) -> None:
        """Download the QMNIST data if it doesn't exist already.
        Note that we only download what has been asked for (argument 'what').
        """
        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)
        split = self.resources[self.subsets[self.what]]

        for url, md5 in split:
            download_and_extract_archive(url, self.raw_folder, md5=md5)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        # redefined to handle the compat flag
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img.numpy(), mode="L")
        if self.transform is not None:
            img = self.transform(img)
        if self.compat:
            target = int(target[0])
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def extra_repr(self) -> str:
        return f"Split: {self.what}"


def get_int(b: bytes) -> int:
    return int(codecs.encode(b, "hex"), 16)


SN3_PASCALVINCENT_BITSMAP = {
    8: torch.uint8,
    9: torch.int8,
    11: torch.int16,
    12: torch.int32,
    13: torch.float32,
    14: torch.float64,
}

TORCH_TYPE_BITS = {
    torch.uint8: 8,
    torch.int8: 8,
    torch.int16: 16,
    torch.int32: 32,
    torch.float32: 32,
    torch.float64: 64,
}


def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
    Argument may be a filename, compressed filename, or file object.
    """
    # read
    with open(path, "rb") as f:
        data = f.read()
    # parse
    magic = get_int(data[0:4])
    nd = magic % 256
    ty = magic // 256
    assert 1 <= nd <= 3
    assert 8 <= ty <= 14
    torch_type = SN3_PASCALVINCENT_BITSMAP[ty]
    s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]

    num_bytes_per_value = TORCH_TYPE_BITS[torch_type] // 8
    # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
    # we need to reverse the bytes before we can read them with torch.frombuffer().
    needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
    parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))
    if needs_byte_reversal:
        parsed = parsed.flip(0)

    assert parsed.shape[0] == np.prod(s) or not strict
    return parsed.view(*s)


def read_label_file(path: str) -> torch.Tensor:
    x = read_sn3_pascalvincent_tensor(path, strict=False)
    if x.dtype != torch.uint8:
        raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
    if x.ndimension() != 1:
        raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}")
    return x.long()


def read_image_file(path: str) -> torch.Tensor:
    x = read_sn3_pascalvincent_tensor(path, strict=False)
    if x.dtype != torch.uint8:
        raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
    if x.ndimension() != 3:
        raise ValueError(f"x should have 3 dimensions instead of {x.ndimension()}")
    return x
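A short hedged usage sketch of the classes above; "./data" is an illustrative root directory, not part of the file.

# Illustrative usage: download once, then index like any map-style dataset.
train_set = MNIST(root="./data", train=True, download=True)
img, label = train_set[0]  # PIL.Image.Image in mode "L", python int
print(len(train_set), img.size, label)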
@ -1,522 +0,0 @@
import bz2
import contextlib
import gzip
import hashlib
import itertools
import lzma
import os
import os.path
import pathlib
import re
import sys
import tarfile
import urllib
import urllib.error
import urllib.request
import warnings
import zipfile
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar
from urllib.parse import urlparse

import numpy as np
import requests
import torch
from tqdm import tqdm

from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available

USER_AGENT = "pytorch/vision"


def _save_response_content(
    content: Iterator[bytes],
    destination: str,
    length: Optional[int] = None,
) -> None:
    with open(destination, "wb") as fh, tqdm(total=length) as pbar:
        for chunk in content:
            # filter out keep-alive new chunks
            if not chunk:
                continue

            fh.write(chunk)
            pbar.update(len(chunk))


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
        _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)


def gen_bar_updater() -> Callable[[int, int, int], None]:
    warnings.warn("The function `gen_bar_update` is deprecated since 0.13 and will be removed in 0.15.")
    pbar = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        if pbar.total is None and total_size:
            pbar.total = total_size
        progress_bytes = count * block_size
        pbar.update(progress_bytes - pbar.n)

    return bar_update


def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
    # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
    # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
    # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere.
    if sys.version_info >= (3, 9):
        md5 = hashlib.md5(usedforsecurity=False)
    else:
        md5 = hashlib.md5()
    with open(fpath, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    return check_md5(fpath, md5)


def _get_redirect_url(url: str, max_hops: int = 3) -> str:
    initial_url = url
    headers = {"Method": "HEAD", "User-Agent": USER_AGENT}

    for _ in range(max_hops + 1):
        with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
            if response.url == url or response.url is None:
                return url

            url = response.url
    else:
        raise RecursionError(
            f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}."
        )


def _get_google_drive_file_id(url: str) -> Optional[str]:
    parts = urlparse(url)

    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
        return None

    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
    if match is None:
        return None

    return match.group("id")


def download_url(
    url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
) -> None:
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check
        max_redirect_hops (int, optional): Maximum number of redirect hops allowed
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print("Using downloaded and verified file: " + fpath)
        return

    if _is_remote_location_available():
        _download_file_from_remote_location(fpath, url)
    else:
        # expand redirect chain if needed
        url = _get_redirect_url(url, max_hops=max_redirect_hops)

        # check if file is located on Google Drive
        file_id = _get_google_drive_file_id(url)
        if file_id is not None:
            return download_file_from_google_drive(file_id, root, filename, md5)

        # download the file
        try:
            print("Downloading " + url + " to " + fpath)
            _urlretrieve(url, fpath)
        except (urllib.error.URLError, OSError) as e:  # type: ignore[attr-defined]
            if url[:5] == "https":
                url = url.replace("https:", "http:")
                print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
                _urlretrieve(url, fpath)
            else:
                raise e

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")


def list_dir(root: str, prefix: bool = False) -> List[str]:
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
    if prefix is True:
        directories = [os.path.join(root, d) for d in directories]
    return directories


def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
    if prefix is True:
        files = [os.path.join(root, d) for d in files]
    return files


def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]:
    content = response.iter_content(chunk_size)
    first_chunk = None
    # filter out keep-alive new chunks
    while not first_chunk:
        first_chunk = next(content)
    content = itertools.chain([first_chunk], content)

    try:
        match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
        api_response = match["api_response"] if match is not None else None
    except UnicodeDecodeError:
        api_response = None
    return api_response, content


def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
    """Download a Google Drive file and place it in root.

    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the id of the file.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url

    root = os.path.expanduser(root)
    if not filename:
        filename = file_id
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    if check_integrity(fpath, md5):
        print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
        return

    url = "https://drive.google.com/uc"
    params = dict(id=file_id, export="download")
    with requests.Session() as session:
        response = session.get(url, params=params, stream=True)

        for key, value in response.cookies.items():
            if key.startswith("download_warning"):
                token = value
                break
        else:
            api_response, content = _extract_gdrive_api_response(response)
            token = "t" if api_response == "Virus scan warning" else None

        if token is not None:
            response = session.get(url, params=dict(params, confirm=token), stream=True)
            api_response, content = _extract_gdrive_api_response(response)

        if api_response == "Quota exceeded":
            raise RuntimeError(
                f"The daily quota of the file {filename} is exceeded and it "
                f"can't be downloaded. This is a limitation of Google Drive "
                f"and can only be overcome by trying again later."
            )

        _save_response_content(content, fpath)

    # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
    if os.stat(fpath).st_size < 10 * 1024:
        with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh:
            text = fh.read()
            # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
            if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
                warnings.warn(
                    f"We detected some HTML elements in the downloaded file. "
                    f"This most likely means that the download triggered an unhandled API response by GDrive. "
                    f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
                    f"the response:\n\n{text}"
                )

    if md5 and not check_md5(fpath, md5):
        raise RuntimeError(
            f"The MD5 checksum of the download file {fpath} does not match the one on record. "
            f"Please delete the file and try again. "
            f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
        )


def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
    with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
        tar.extractall(to_path)


_ZIP_COMPRESSION_MAP: Dict[str, int] = {
    ".bz2": zipfile.ZIP_BZIP2,
    ".xz": zipfile.ZIP_LZMA,
}


def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
    with zipfile.ZipFile(
        from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
    ) as zip:
        zip.extractall(to_path)


_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
    ".tar": _extract_tar,
    ".zip": _extract_zip,
}
_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {
    ".bz2": bz2.open,
    ".gz": gzip.open,
    ".xz": lzma.open,
}
_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
    ".tbz": (".tar", ".bz2"),
    ".tbz2": (".tar", ".bz2"),
    ".tgz": (".tar", ".gz"),
}


def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Detect the archive type and/or compression of a file.

    Args:
        file (str): the filename

    Returns:
        (tuple): tuple of suffix, archive type, and compression

    Raises:
        RuntimeError: if file has no suffix or suffix is not supported
    """
    suffixes = pathlib.Path(file).suffixes
    if not suffixes:
        raise RuntimeError(
            f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
        )
    suffix = suffixes[-1]

    # check if the suffix is a known alias
    if suffix in _FILE_TYPE_ALIASES:
        return (suffix, *_FILE_TYPE_ALIASES[suffix])

    # check if the suffix is an archive type
    if suffix in _ARCHIVE_EXTRACTORS:
        return suffix, suffix, None

    # check if the suffix is a compression
    if suffix in _COMPRESSED_FILE_OPENERS:
        # check for suffix hierarchy
        if len(suffixes) > 1:
            suffix2 = suffixes[-2]

            # check if the suffix2 is an archive type
            if suffix2 in _ARCHIVE_EXTRACTORS:
                return suffix2 + suffix, suffix2, suffix

        return suffix, None, suffix

    valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS))
    raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")


def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    r"""Decompress a file.

    The compression is automatically detected from the file name.

    Args:
        from_path (str): Path to the file to be decompressed.
        to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the decompressed file.
    """
    suffix, archive_type, compression = _detect_file_type(from_path)
    if not compression:
        raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")

    if to_path is None:
        to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]

    with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
        wfh.write(rfh.read())

    if remove_finished:
        os.remove(from_path)

    return to_path


def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    """Extract an archive.

    The archive type and a possible compression is automatically detected from the file name. If the file is compressed
    but not an archive the call is dispatched to :func:`decompress`.

    Args:
        from_path (str): Path to the file to be extracted.
        to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
            used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the directory the file was extracted to.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    suffix, archive_type, compression = _detect_file_type(from_path)
    if not archive_type:
        return _decompress(
            from_path,
            os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
            remove_finished=remove_finished,
        )

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    extractor = _ARCHIVE_EXTRACTORS[archive_type]

    extractor(from_path, to_path, compression)
    if remove_finished:
        os.remove(from_path)

    return to_path


def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print(f"Extracting {archive} to {extract_root}")
    extract_archive(archive, extract_root, remove_finished)


def iterable_to_str(iterable: Iterable) -> str:
    return "'" + "', '".join([str(item) for item in iterable]) + "'"


T = TypeVar("T", str, bytes)


def verify_str_arg(
    value: T,
    arg: Optional[str] = None,
    valid_values: Optional[Iterable[T]] = None,
    custom_msg: Optional[str] = None,
) -> T:
    if not isinstance(value, torch._six.string_classes):
        if arg is None:
            msg = "Expected type str, but got type {type}."
        else:
            msg = "Expected type str for argument {arg}, but got type {type}."
        msg = msg.format(type=type(value), arg=arg)
        raise ValueError(msg)

    if valid_values is None:
        return value

    if value not in valid_values:
        if custom_msg is not None:
            msg = custom_msg
        else:
            msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}."
            msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
        raise ValueError(msg)

    return value


def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
    """Read file in .pfm format. Might contain either 1 or 3 channels of data.

    Args:
        file_name (str): Path to the file.
        slice_channels (int): Number of channels to slice out of the file.
            Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc.
    """

    with open(file_name, "rb") as f:
        header = f.readline().rstrip()
        if header not in [b"PF", b"Pf"]:
            raise ValueError("Invalid PFM file")

        dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
        if not dim_match:
            raise Exception("Malformed PFM header.")
        w, h = (int(dim) for dim in dim_match.groups())

        scale = float(f.readline().rstrip())
        if scale < 0:  # little-endian
            endian = "<"
            scale = -scale
        else:
            endian = ">"  # big-endian

        data = np.fromfile(f, dtype=endian + "f")

    pfm_channels = 3 if header == b"PF" else 1

    data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1)
    data = np.flip(data, axis=1)  # flip on h dimension
    data = data[:slice_channels, :, :]
    return data.astype(np.float32)
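To make the suffix-detection contract concrete, a few worked calls (pure logic, no I/O; traced against the branches above):

# (suffix, archive_type, compression) for common file names.
assert _detect_file_type("mnist.tar.gz") == (".tar.gz", ".tar", ".gz")
assert _detect_file_type("images.tgz") == (".tgz", ".tar", ".gz")
assert _detect_file_type("labels.gz") == (".gz", None, ".gz")
assert _detect_file_type("archive.zip") == (".zip", ".zip", None)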
@ -1,104 +0,0 @@
import os
from typing import Any, Callable, List, Optional, Tuple

import torch
import torch.utils.data as data

from ..utils import _log_api_usage_once


class VisionDataset(data.Dataset):
    """
    Base Class For making datasets which are compatible with torchvision.
    It is necessary to override the ``__getitem__`` and ``__len__`` method.

    Args:
        root (string): Root directory of dataset.
        transforms (callable, optional): A function/transforms that takes in
            an image and a label and returns the transformed versions of both.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    .. note::
        :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
    """

    _repr_indent = 4

    def __init__(
        self,
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        self.root = root

        has_transforms = transforms is not None
        has_separate_transform = transform is not None or target_transform is not None
        if has_transforms and has_separate_transform:
            raise ValueError("Only transforms or transform/target_transform can be passed as argument")

        # for backwards-compatibility
        self.transform = transform
        self.target_transform = target_transform

        if has_separate_transform:
            transforms = StandardTransform(transform, target_transform)
        self.transforms = transforms

    def __getitem__(self, index: int) -> Any:
        """
        Args:
            index (int): Index

        Returns:
            (Any): Sample and meta data, optionally transformed by the respective transforms.
        """
        raise NotImplementedError

    def __len__(self) -> int:
        raise NotImplementedError

    def __repr__(self) -> str:
        head = "Dataset " + self.__class__.__name__
        body = [f"Number of datapoints: {self.__len__()}"]
        if self.root is not None:
            body.append(f"Root location: {self.root}")
        body += self.extra_repr().splitlines()
        if hasattr(self, "transforms") and self.transforms is not None:
            body += [repr(self.transforms)]
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return "\n".join(lines)

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        lines = transform.__repr__().splitlines()
        return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]

    def extra_repr(self) -> str:
        return ""


class StandardTransform:
    def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
        if self.transform is not None:
            input = self.transform(input)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return input, target

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        lines = transform.__repr__().splitlines()
        return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]

    def __repr__(self) -> str:
        body = [self.__class__.__name__]
        if self.transform is not None:
            body += self._format_transform_repr(self.transform, "Transform: ")
        if self.target_transform is not None:
            body += self._format_transform_repr(self.target_transform, "Target transform: ")

        return "\n".join(body)
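A hedged sketch of the intended subclassing pattern; SquaresDataset and its integer samples are illustrative only, not part of the file.

# Hypothetical subclass of VisionDataset.
class SquaresDataset(VisionDataset):
    def __init__(self, root, n=100, transform=None, target_transform=None):
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.n = n

    def __getitem__(self, index):
        sample, target = index, index * index
        # self.transforms was assembled from transform/target_transform in __init__
        if self.transforms is not None:
            sample, target = self.transforms(sample, target)
        return sample, target

    def __len__(self):
        return self.n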
@ -1 +0,0 @@
from jittor.transform import *
@ -1,582 +0,0 @@
import collections
import math
import pathlib
import warnings
from itertools import repeat
from types import FunctionType
from typing import Any, BinaryIO, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image, ImageColor, ImageDraw, ImageFont

__all__ = [
    "make_grid",
    "save_image",
    "draw_bounding_boxes",
    "draw_segmentation_masks",
    "draw_keypoints",
    "flow_to_image",
]


@torch.no_grad()
def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: float = 0.0,
    **kwargs,
) -> torch.Tensor:
    """
    Make a grid of images.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list of images all of the same size.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
        padding (int, optional): amount of padding. Default: ``2``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by ``value_range``. Default: ``False``.
        value_range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. By default, min and max
            are computed from the tensor.
        range (tuple, optional):
            .. warning::
                This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``value_range``
                instead.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than the (min, max) over all images. Default: ``False``.
        pad_value (float, optional): Value for the padded pixels. Default: ``0``.

    Returns:
        grid (Tensor): the tensor containing grid of images.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(make_grid)
    if not torch.is_tensor(tensor):
        if isinstance(tensor, list):
            for t in tensor:
                if not torch.is_tensor(t):
                    raise TypeError(f"tensor or list of tensors expected, got a list containing {type(t)}")
        else:
            raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")

    if "range" in kwargs.keys():
        warnings.warn(
            "The parameter 'range' is deprecated since 0.12 and will be removed in 0.14. "
            "Please use 'value_range' instead."
        )
        value_range = kwargs["range"]

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize is True:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        if value_range is not None and not isinstance(value_range, tuple):
            raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers")

        def norm_ip(img, low, high):
            img.clamp_(min=low, max=high)
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, value_range)
        else:
            norm_range(tensor, value_range)

    if not isinstance(tensor, torch.Tensor):
        raise TypeError("tensor should be of type torch.Tensor")
    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            # Tensor.copy_() is a valid method but seems to be missing from the stubs
            # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
            grid.narrow(1, y * height + padding, height - padding).narrow(  # type: ignore[attr-defined]
                2, x * width + padding, width - padding
            ).copy_(tensor[k])
            k = k + 1
    return grid
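A hedged sketch of make_grid in use; the batch shape is illustrative.

# Tile a batch of 16 fake RGB images into a 4-per-row grid.
batch = torch.rand(16, 3, 32, 32)
grid = make_grid(batch, nrow=4, padding=2, normalize=True)
print(grid.shape)  # [3, 4*(32+2)+2, 4*(32+2)+2] -> torch.Size([3, 138, 138])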
|
||||
|
||||
@torch.no_grad()
|
||||
def save_image(
|
||||
tensor: Union[torch.Tensor, List[torch.Tensor]],
|
||||
fp: Union[str, pathlib.Path, BinaryIO],
|
||||
format: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Save a given Tensor into an image file.
|
||||
|
||||
Args:
|
||||
tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
|
||||
saves the tensor as a grid of images by calling ``make_grid``.
|
||||
fp (string or file object): A filename or a file object
|
||||
format(Optional): If omitted, the format to use is determined from the filename extension.
|
||||
If a file object was used instead of a filename, this parameter should always be used.
|
||||
**kwargs: Other arguments are documented in ``make_grid``.
|
||||
"""
|
||||
|
||||
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
||||
_log_api_usage_once(save_image)
|
||||
grid = make_grid(tensor, **kwargs)
|
||||
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
|
||||
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
||||
im = Image.fromarray(ndarr)
|
||||
im.save(fp, format=format)
|
||||
|
||||
|
||||
@torch.no_grad()
def draw_bounding_boxes(
    image: torch.Tensor,
    boxes: torch.Tensor,
    labels: Optional[List[str]] = None,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
    fill: Optional[bool] = False,
    width: int = 1,
    font: Optional[str] = None,
    font_size: Optional[int] = None,
) -> torch.Tensor:

    """
    Draws bounding boxes on given image.
    The values of the input image should be uint8 between 0 and 255.
    If fill is True, the resulting Tensor should be saved as a PNG image.

    Args:
        image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
        boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
            the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
            `0 <= ymin < ymax < H`.
        labels (List[str]): List containing the labels of bounding boxes.
        colors (color or list of colors, optional): List containing the colors
            of the boxes or single color for all boxes. The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            By default, random colors are generated for boxes.
        fill (bool): If `True` fills the bounding box with specified color.
        width (int): Width of bounding box.
        font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
            also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
            `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
        font_size (int): The requested font size in points.

    Returns:
        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
    """

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_bounding_boxes)
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Tensor expected, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size(0) not in {1, 3}:
        raise ValueError("Only grayscale and RGB images are supported")
    elif (boxes[:, 0] > boxes[:, 2]).any() or (boxes[:, 1] > boxes[:, 3]).any():
        raise ValueError(
            "Boxes need to be in (xmin, ymin, xmax, ymax) format. Use torchvision.ops.box_convert to convert them"
        )

    num_boxes = boxes.shape[0]

    if num_boxes == 0:
        warnings.warn("boxes doesn't contain any box. No box was drawn")
        return image

    if labels is None:
        labels: Union[List[str], List[None]] = [None] * num_boxes  # type: ignore[no-redef]
    elif len(labels) != num_boxes:
        raise ValueError(
            f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box."
        )

    if colors is None:
        colors = _generate_color_palette(num_boxes)
    elif isinstance(colors, list):
        if len(colors) < num_boxes:
            raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ")
    else:  # colors specifies a single color for all boxes
        colors = [colors] * num_boxes

    colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors]

    if font is None:
        if font_size is not None:
            warnings.warn("Argument 'font_size' will be ignored since 'font' is not set.")
        txt_font = ImageFont.load_default()
    else:
        txt_font = ImageFont.truetype(font=font, size=font_size or 10)

    # Handle Grayscale images
    if image.size(0) == 1:
        image = torch.tile(image, (3, 1, 1))

    ndarr = image.permute(1, 2, 0).cpu().numpy()
    img_to_draw = Image.fromarray(ndarr)
    img_boxes = boxes.to(torch.int64).tolist()

    if fill:
        draw = ImageDraw.Draw(img_to_draw, "RGBA")
    else:
        draw = ImageDraw.Draw(img_to_draw)

    for bbox, color, label in zip(img_boxes, colors, labels):  # type: ignore[arg-type]
        if fill:
            fill_color = color + (100,)
            draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
        else:
            draw.rectangle(bbox, width=width, outline=color)

        if label is not None:
            margin = width + 1
            draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)

    return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)

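# A minimal usage sketch; boxes are absolute (xmin, ymin, xmax, ymax) pixel
# coordinates, and the labels/colors below are illustrative values:
#
#   img = torch.zeros(3, 100, 100, dtype=torch.uint8)
#   boxes = torch.tensor([[10, 10, 50, 50], [60, 20, 90, 80]], dtype=torch.float)
#   out = draw_bounding_boxes(img, boxes, labels=["cat", "dog"],
#                             colors=["red", "#00FF00"], width=2)
#   # out: uint8 Tensor[3, 100, 100] with two labeled rectangles drawn
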
@torch.no_grad()
def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    alpha: float = 0.8,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
) -> torch.Tensor:

    """
    Draws segmentation masks on given RGB image.
    The values of the input image should be uint8 between 0 and 255.

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
        masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
        alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
            0 means full transparency, 1 means no transparency.
        colors (color or list of colors, optional): List containing the colors
            of the masks or single color for all masks. The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            By default, random colors are generated for each mask.

    Returns:
        img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
    """

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_segmentation_masks)
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")
    if masks.ndim == 2:
        masks = masks[None, :, :]
    if masks.ndim != 3:
        raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
    if masks.dtype != torch.bool:
        raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
    if masks.shape[-2:] != image.shape[-2:]:
        raise ValueError("The image and the masks must have the same height and width")

    num_masks = masks.size()[0]
    if colors is not None and num_masks > len(colors):
        raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")

    if num_masks == 0:
        warnings.warn("masks doesn't contain any mask. No mask was drawn")
        return image

    if colors is None:
        colors = _generate_color_palette(num_masks)

    if not isinstance(colors, list):
        colors = [colors]
    if not isinstance(colors[0], (tuple, str)):
        raise ValueError("colors must be a tuple or a string, or a list thereof")
    if isinstance(colors[0], tuple) and len(colors[0]) != 3:
        raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")

    out_dtype = torch.uint8

    colors_ = []
    for color in colors:
        if isinstance(color, str):
            color = ImageColor.getrgb(color)
        colors_.append(torch.tensor(color, dtype=out_dtype))

    img_to_draw = image.detach().clone()
    # TODO: There might be a way to vectorize this
    for mask, color in zip(masks, colors_):
        img_to_draw[:, mask] = color[:, None]

    out = image * (1 - alpha) + img_to_draw * alpha
    return out.to(out_dtype)

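# A minimal usage sketch; masks are boolean with the same H x W as the image,
# and the colors below are illustrative values:
#
#   img = torch.zeros(3, 64, 64, dtype=torch.uint8)
#   masks = torch.zeros(2, 64, 64, dtype=torch.bool)
#   masks[0, :32] = True
#   masks[1, 32:] = True
#   out = draw_segmentation_masks(img, masks, alpha=0.5, colors=["blue", (255, 0, 0)])
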
@torch.no_grad()
def draw_keypoints(
    image: torch.Tensor,
    keypoints: torch.Tensor,
    connectivity: Optional[List[Tuple[int, int]]] = None,
    colors: Optional[Union[str, Tuple[int, int, int]]] = None,
    radius: int = 2,
    width: int = 3,
) -> torch.Tensor:

    """
    Draws Keypoints on given RGB image.
    The values of the input image should be uint8 between 0 and 255.

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
        keypoints (Tensor): Tensor of shape (num_instances, K, 2), holding the K keypoint locations
            for each of the N instances, in the format [x, y].
        connectivity (List[Tuple[int, int]]): A list of tuples, where
            each tuple contains a pair of keypoints to be connected.
        colors (str, Tuple): The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
        radius (int): Integer denoting radius of keypoint.
        width (int): Integer denoting width of line connecting keypoints.

    Returns:
        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
    """

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_keypoints)
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")

    if keypoints.ndim != 3:
        raise ValueError("keypoints must be of shape (num_instances, K, 2)")

    ndarr = image.permute(1, 2, 0).cpu().numpy()
    img_to_draw = Image.fromarray(ndarr)
    draw = ImageDraw.Draw(img_to_draw)
    img_kpts = keypoints.to(torch.int64).tolist()

    for kpt_id, kpt_inst in enumerate(img_kpts):
        for inst_id, kpt in enumerate(kpt_inst):
            x1 = kpt[0] - radius
            x2 = kpt[0] + radius
            y1 = kpt[1] - radius
            y2 = kpt[1] + radius
            draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0)

        if connectivity:
            for connection in connectivity:
                start_pt_x = kpt_inst[connection[0]][0]
                start_pt_y = kpt_inst[connection[0]][1]

                end_pt_x = kpt_inst[connection[1]][0]
                end_pt_y = kpt_inst[connection[1]][1]

                draw.line(
                    ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)),
                    width=width,
                )

    return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)

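# A minimal usage sketch; keypoints are (num_instances, K, 2) in [x, y] pixel
# order, and connectivity indexes pairs of the K keypoints (illustrative values):
#
#   img = torch.zeros(3, 64, 64, dtype=torch.uint8)
#   kpts = torch.tensor([[[10.0, 10.0], [30.0, 40.0], [50.0, 20.0]]])  # 1 instance, K=3
#   out = draw_keypoints(img, kpts, connectivity=[(0, 1), (1, 2)], colors="yellow", radius=3)
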
# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
@torch.no_grad()
def flow_to_image(flow: torch.Tensor) -> torch.Tensor:

    """
    Converts a flow to an RGB image.

    Args:
        flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.

    Returns:
        img (Tensor): Image Tensor of dtype uint8 where each color corresponds
            to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
    """

    if flow.dtype != torch.float:
        raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")

    orig_shape = flow.shape
    if flow.ndim == 3:
        flow = flow[None]  # Add batch dim

    if flow.ndim != 4 or flow.shape[1] != 2:
        raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")

    max_norm = torch.sum(flow**2, dim=1).sqrt().max()
    epsilon = torch.finfo(flow.dtype).eps
    normalized_flow = flow / (max_norm + epsilon)
    img = _normalized_flow_to_image(normalized_flow)

    if len(orig_shape) == 3:
        img = img[0]  # Remove batch dim
    return img

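# A minimal usage sketch; flow direction maps to hue and magnitude to color
# intensity via the colorwheel below:
#
#   flow = torch.randn(2, 48, 48)                         # (2, H, W): per-pixel (dx, dy)
#   rgb = flow_to_image(flow)                             # uint8 Tensor[3, 48, 48]
#   batched = flow_to_image(torch.randn(4, 2, 48, 48))    # uint8 Tensor[4, 3, 48, 48]
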
@torch.no_grad()
def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:

    """
    Converts a batch of normalized flow to an RGB image.

    Args:
        normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
    Returns:
        img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
    """

    N, _, H, W = normalized_flow.shape
    device = normalized_flow.device
    flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
    colorwheel = _make_colorwheel().to(device)  # shape [55x3]
    num_cols = colorwheel.shape[0]
    norm = torch.sum(normalized_flow**2, dim=1).sqrt()
    a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
    fk = (a + 1) / 2 * (num_cols - 1)
    k0 = torch.floor(fk).to(torch.long)
    k1 = k0 + 1
    k1[k1 == num_cols] = 0
    f = fk - k0

    for c in range(colorwheel.shape[1]):
        tmp = colorwheel[:, c]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        col = (1 - f) * col0 + f * col1
        col = 1 - norm * (1 - col)
        flow_image[:, c, :, :] = torch.floor(255 * col)
    return flow_image

def _make_colorwheel() -> torch.Tensor:
    """
    Generates a color wheel for optical flow visualization as presented in:
    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.

    Returns:
        colorwheel (Tensor[55, 3]): Colorwheel Tensor.
    """

    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = torch.zeros((ncols, 3))
    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY)
    col = col + RY
    # YG
    colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG)
    colorwheel[col : col + YG, 1] = 255
    col = col + YG
    # GC
    colorwheel[col : col + GC, 1] = 255
    colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC)
    col = col + GC
    # CB
    colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB)
    colorwheel[col : col + CB, 2] = 255
    col = col + CB
    # BM
    colorwheel[col : col + BM, 2] = 255
    colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM)
    col = col + BM
    # MR
    colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR)
    colorwheel[col : col + MR, 0] = 255
    return colorwheel

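# A small sanity check of the wheel construction: the six segments sum to
# 15 + 6 + 4 + 11 + 13 + 6 = 55 rows, matching the Tensor[55, 3] described above:
#
#   wheel = _make_colorwheel()
#   assert wheel.shape == (55, 3)
#   assert wheel[0].tolist() == [255.0, 0.0, 0.0]   # the wheel starts at pure red
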
def _generate_color_palette(num_objects: int):
    palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1])
    return [tuple((i * palette) % 255) for i in range(num_objects)]

def _log_api_usage_once(obj: Any) -> None:

    """
    Logs API usage (module and name) within an organization.
    In a large ecosystem, it's often useful to track the PyTorch and
    TorchVision APIs usage. This API provides the similar functionality to the
    logging module in the Python stdlib. It can be used for debugging purposes
    to log which methods are used and by default it is inactive, unless the user
    manually subscribes a logger via the `SetAPIUsageLogger method <https://github.com/pytorch/pytorch/blob/eb3b9fe719b21fae13c7a7cf3253f970290a573e/c10/util/Logging.cpp#L114>`_.
    Please note it is triggered only once for the same API call within a process.
    It does not collect any data from open-source users since it is no-op by default.
    For more information, please refer to
    * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging;
    * Logging policy: https://github.com/pytorch/vision/issues/5052;

    Args:
        obj (class instance or method): an object to extract info from.
    """
    pass

def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
    """
    Make an n-tuple from input x. If x is an iterable, then we just convert it to a tuple.
    Otherwise we will make a tuple of length n, all with the value of x.
    reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8

    Args:
        x (Any): input value
        n (int): length of the resulting tuple
    """
    if isinstance(x, collections.abc.Iterable):
        return tuple(x)
    return tuple(repeat(x, n))

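# A couple of illustrative calls (note that n is ignored for iterable inputs):
#
#   _make_ntuple(3, 2)        # -> (3, 3)
#   _make_ntuple([1, 2], 2)   # -> (1, 2)
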
@@ -3537,38 +3537,4 @@ def mish(x, inplace=False):
def skip_init(module_cls, *args, **kw):
    return module_cls(*args, **kw)

'''
Only for extern/acl
'''
class FlashAttention(Module):
    def __init__(self, headnum,
                 layout="BSND",
                 prefix=None,
                 qstart=None,
                 kvstart=None,
                 scale=1.0,
                 prob=1.0,
                 pretokens=2147483647,
                 nexttokens=2147483647,
                 innerprecise=0,
                 sparsemode=0,
                 psetype=1):
        self.headnum = headnum
        self.layout = layout
        self.prefix = prefix
        self.qstart = qstart
        self.kvstart = kvstart
        self.scale = scale
        self.prob = prob
        self.pretokens = pretokens
        self.nexttokens = nexttokens
        self.innerprecise = innerprecise
        self.sparsemode = sparsemode
        self.psetype = psetype

    def execute(self, q, k, v,
                realshift=None,
                dropMask=None,
                paddingMask=None,
                attenMask=None):
        pass

@@ -2,6 +2,7 @@ import jittor as jt
from jittor.nn import ComplexNumber
import unittest
import numpy as np
from functools import partial

__skip_torch_test = False
try:
@@ -10,6 +11,15 @@ except:
    __skip_torch_test = True

class TestResultAndGrad:
    def flatten_list(self, list_like):
        results = []
        if isinstance(list_like, (list, tuple)):
            for x in list_like:
                results.extend(self.flatten_list(x))
            return results
        else:
            return [list_like]

    def check_results(self, rlist1, rlist2):
        assert len(rlist1) == len(rlist2)
        for r1, r2 in zip(rlist1, rlist2):
@@ -36,13 +46,21 @@ class TestResultAndGrad:
            grads.append(g.detach().cpu().numpy())
        return grads

    def run_jittor_op(self, op, input_list, weights=None):
        def _np_to_jittor(x: np.ndarray):
            if x.dtype == np.complex64 or x.dtype == np.complex128:
                nx = np.stack([np.real(x), np.imag(x)], axis=-1)
                return ComplexNumber(jt.array(nx, dtype=jt.float32), is_concat_value=True)
            elif x.dtype == np.float32 or x.dtype == np.float64:
                return jt.array(x, dtype=jt.float32)
    def run_jittor_op(self, op, input_list, weights=None, key_names=None, **kwargs):
        def _np_to_jittor(x):
            if isinstance(x, np.ndarray):
                if x.dtype == np.complex64 or x.dtype == np.complex128:
                    nx = np.stack([np.real(x), np.imag(x)], axis=-1)
                    return ComplexNumber(jt.array(nx, dtype=jt.float32), is_concat_value=True)
                elif x.dtype == np.float32 or x.dtype == np.float64:
                    return jt.array(x, dtype=jt.float32)
                else:
                    assert False
            elif isinstance(x, (list, tuple)):
                nx = [_np_to_jittor(vx) for vx in x]
                if isinstance(x, tuple):
                    return tuple(nx)
                return nx
            else:
                assert False
        def _jittor_to_np(x):
@@ -51,11 +69,19 @@ class TestResultAndGrad:
            elif isinstance(x, ComplexNumber):
                return x.real.numpy() + 1j * x.imag.numpy()
            assert False

        ninput_list = [_np_to_jittor(x) for x in input_list]
        output_list = op(*ninput_list)

        if key_names is not None:
            assert len(ninput_list) == len(key_names)
            nkwargs = kwargs.copy()
            for k, v in zip(key_names, ninput_list):
                nkwargs[k] = v
            output_list = op(**nkwargs)
        else:
            output_list = op(*ninput_list, **kwargs)
        if isinstance(output_list, (jt.Var, ComplexNumber)):
            output_list = [output_list]
        output_list = self.flatten_list(output_list)
        losses = []
        if weights is None:
            weights = []
@@ -73,15 +99,31 @@ class TestResultAndGrad:
        output_list = [_jittor_to_np(x) for x in output_list]
        return ninput_list, output_list, losses, weights

    def run_torch_op(self, op, input_list, weights=None):
        def _np_to_torch(x: np.ndarray):
            return torch.from_numpy(x).requires_grad_(True)
    def run_torch_op(self, op, input_list, weights=None, key_names=None, **kwargs):
        def _np_to_torch(x):
            if isinstance(x, np.ndarray):
                return torch.from_numpy(x).requires_grad_(True)
            elif isinstance(x, (list, tuple)):
                nx = [_np_to_torch(vx) for vx in x]
                if isinstance(x, tuple):
                    return tuple(nx)
                return nx
            else:
                assert False
        def _torch_to_np(x: torch.Tensor) -> np.ndarray:
            return x.detach().cpu().numpy()
        ninput_list = [_np_to_torch(x) for x in input_list]
        output_list = op(*ninput_list)
        if key_names is not None:
            assert len(ninput_list) == len(key_names)
            nkwargs = kwargs.copy()
            for k, v in zip(key_names, ninput_list):
                nkwargs[k] = v
            output_list = op(**nkwargs)
        else:
            output_list = op(*ninput_list, **kwargs)
        if isinstance(output_list, torch.Tensor):
            output_list = [output_list]
        output_list = self.flatten_list(output_list)
        losses = []
        if weights is None:
            weights = []
@@ -99,10 +141,10 @@ class TestResultAndGrad:
        output_list = [_torch_to_np(x) for x in output_list]
        return ninput_list, output_list, losses, weights

    def check_op_with_torch(self, jittor_op, torch_op, input_list, check_grad=True):
    def check_op_with_torch(self, jittor_op, torch_op, input_list, check_grad=True, jittor_knames=None, torch_knames=None, **kwargs):
        weights = None
        jittor_input, jittor_output, jittor_losses, weights = self.run_jittor_op(jittor_op, input_list, weights)
        torch_input, torch_output, torch_losses, weights = self.run_torch_op(torch_op, input_list, weights)
        jittor_input, jittor_output, jittor_losses, weights = self.run_jittor_op(jittor_op, input_list, weights, key_names=jittor_knames, **kwargs)
        torch_input, torch_output, torch_losses, weights = self.run_torch_op(torch_op, input_list, weights, key_names=torch_knames, **kwargs)
        self.check_results(jittor_output, torch_output)

        if check_grad:
@@ -195,6 +237,249 @@ class TestComplexLinalg(unittest.TestCase, TestResultAndGrad):
        inputs = [m1]
        self.check_op_with_numpy(jt.linalg.svd, np.linalg.svd, inputs)

class TestTensordot(unittest.TestCase, TestResultAndGrad):
    def random_complex_matrix(self, shape):
        r = np.random.randn(*shape)
        i = np.random.randn(*shape)
        return r + 1j * i

    def random_real_matrix(self, shape):
        return np.random.randn(*shape)

    def test_complex_tensordot_numberdim(self):
        s1 = (3, 4, 5)
        s2 = (4, 5, 6)
        dims = 2
        m1 = self.random_complex_matrix(s1)
        m2 = self.random_complex_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims=dims)

    def test_complex_tensordot_tupledim(self):
        s1 = (3, 5, 4, 6)
        s2 = (6, 4, 5, 3)
        dims = ([2, 1, 3], [1, 2, 0])
        m1 = self.random_complex_matrix(s1)
        m2 = self.random_complex_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims=dims)

    def test_real_tensordot_numberdim(self):
        s1 = (3, 4, 5)
        s2 = (4, 5, 6)
        dims = 2
        m1 = self.random_real_matrix(s1)
        m2 = self.random_real_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims=dims)

    def test_real_tensordot_tupledim(self):
        s1 = (3, 5, 4, 6)
        s2 = (6, 4, 5, 3)
        dims = ([2, 1, 3], [1, 2, 0])
        m1 = self.random_real_matrix(s1)
        m2 = self.random_real_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims=dims)

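# A note on the two `dims` forms exercised above (standard tensordot semantics):
# an integer contracts the last `dims` axes of the first operand with the first
# `dims` axes of the second, while a pair of lists names the contracted axes of
# each operand explicitly. For example:
#
#   np.tensordot(np.ones((3, 4, 5)), np.ones((4, 5, 6)), axes=2).shape   # -> (3, 6)
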
class TestKron(unittest.TestCase, TestResultAndGrad):
    def random_complex_matrix(self, shape):
        r = np.random.randn(*shape)
        i = np.random.randn(*shape)
        return r + 1j * i

    def random_real_matrix(self, shape):
        return np.random.randn(*shape)

    def test_complex_firstlarge(self):
        s1 = (2, 3, 4)
        s2 = (5, 2)
        m1 = self.random_complex_matrix(s1)
        m2 = self.random_complex_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)

    def test_complex_second_large(self):
        s1 = (2, 3)
        s2 = (5, 2, 4)
        m1 = self.random_complex_matrix(s1)
        m2 = self.random_complex_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)

    def test_real_firstlarge(self):
        s1 = (2, 3, 4)
        s2 = (5, 2)
        m1 = self.random_real_matrix(s1)
        m2 = self.random_real_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)

    def test_real_second_large(self):
        s1 = (2, 3)
        s2 = (5, 2, 4)
        m1 = self.random_real_matrix(s1)
        m2 = self.random_real_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)

@unittest.skipIf(__skip_torch_test, "No Torch found")
class TestGradFunctional(unittest.TestCase, TestResultAndGrad):
    def random_complex_matrix(self, shape):
        r = np.random.randn(*shape)
        i = np.random.randn(*shape)
        return r + 1j * i

    def random_real_matrix(self, shape):
        return np.random.randn(*shape) * 0.0 + 1.0

    def test_real_jvp_exp(self):
        def exp_reducer(x):
            return x.exp().sum(dim=1)
        s1 = (5, 6)
        m1 = self.random_real_matrix(s1)
        m2 = self.random_real_matrix(s1)
        inputs = [m1, m2]
        self.check_op_with_torch(
            partial(jt.gradfunctional.jvp, func=exp_reducer, create_graph=True),
            partial(torch.autograd.functional.jvp, func=exp_reducer, create_graph=True),
            inputs,
            jittor_knames=['inputs', 'v'],
            torch_knames=['inputs', 'v'],
            check_grad=False)

    def test_complex_jvp_exp(self):
        def exp_reducer(x):
            return x.exp().sum(1)
        s1 = (5, 6)
        m1 = self.random_complex_matrix(s1)
        m2 = self.random_complex_matrix(s1)
        inputs = [m1, m2]
        self.check_op_with_torch(
            partial(jt.gradfunctional.jvp, func=exp_reducer, create_graph=True),
            partial(torch.autograd.functional.jvp, func=exp_reducer, create_graph=True),
            inputs,
            jittor_knames=['inputs', 'v'],
            torch_knames=['inputs', 'v'],
            check_grad=False,
        )

    def test_real_jvp_add(self):
        w1, w2 = np.random.rand(), np.random.rand()
        def adder(x, y):
            return w1 * x + w2 * y
        s1 = (5, 6)
        m1 = self.random_real_matrix(s1)
        m2 = self.random_real_matrix(s1)
        m3 = self.random_real_matrix(s1)
        m4 = self.random_real_matrix(s1)
        inputs = [(m1, m2), (m3, m4)]
        self.check_op_with_torch(
            partial(jt.gradfunctional.jvp, func=adder, create_graph=True),
            partial(torch.autograd.functional.jvp, func=adder, create_graph=True),
            inputs,
            jittor_knames=['inputs', 'v'],
            torch_knames=['inputs', 'v'],
            check_grad=False,
        )

    def test_complex_jvp_add(self):
        w1r, w1i = np.random.rand(), np.random.rand()
        w2r, w2i = np.random.rand(), np.random.rand()
        def adder_pt(x, y):
            return (w1r + 1j * w1i) * x + (w2r + 1j * w2i) * y
        def adder_jt(x, y):
            w1 = ComplexNumber(real=jt.array(w1r).reshape(1, 1), imag=jt.array(w1i).reshape(1, 1))
            w2 = ComplexNumber(real=jt.array(w2r).reshape(1, 1), imag=jt.array(w2i).reshape(1, 1))
            return w1 * x + w2 * y
        s1 = (5, 6)
        m1 = self.random_complex_matrix(s1)
        m2 = self.random_complex_matrix(s1)
        m3 = self.random_complex_matrix(s1)
        m4 = self.random_complex_matrix(s1)
        inputs = [(m1, m2), (m3, m4)]
        self.check_op_with_torch(
            partial(jt.gradfunctional.jvp, func=adder_jt, create_graph=True),
            partial(torch.autograd.functional.jvp, func=adder_pt, create_graph=True),
            inputs,
            jittor_knames=['inputs', 'v'],
            torch_knames=['inputs', 'v'],
            check_grad=False,
        )

    def test_real_vjp_exp(self):
        def exp_reducer(x):
            return x.exp().sum(dim=1)
        s1 = (5, 6)
        s2 = (5,)
        m1 = self.random_real_matrix(s1)
        m2 = self.random_real_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(
            partial(jt.gradfunctional.vjp, func=exp_reducer),
            partial(torch.autograd.functional.vjp, func=exp_reducer),
            inputs,
            jittor_knames=['inputs', 'v'],
            torch_knames=['inputs', 'v'],
            check_grad=False)

    def test_complex_vjp_exp(self):
        def exp_reducer(x):
            return x.exp().sum(1)
        s1 = (5, 6)
        s2 = (5,)
        m1 = self.random_real_matrix(s1)
        m2 = self.random_real_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(
            partial(jt.gradfunctional.vjp, func=exp_reducer, create_graph=True),
            partial(torch.autograd.functional.vjp, func=exp_reducer, create_graph=True),
            inputs,
            jittor_knames=['inputs', 'v'],
            torch_knames=['inputs', 'v'],
            check_grad=False,
        )

    def test_real_vjp_add(self):
        w1, w2 = np.random.rand(), np.random.rand()
        def adder(x, y):
            return w1 * x + w2 * y
        s1 = (5, 6)
        m1 = self.random_real_matrix(s1)
        m2 = self.random_real_matrix(s1)
        m3 = self.random_real_matrix(s1)
        inputs = [(m1, m2), m3]
        self.check_op_with_torch(
            partial(jt.gradfunctional.vjp, func=adder, create_graph=True),
            partial(torch.autograd.functional.vjp, func=adder, create_graph=True),
            inputs,
            jittor_knames=['inputs', 'v'],
            torch_knames=['inputs', 'v'],
            check_grad=False,
        )

    def test_complex_vjp_add(self):
        w1r, w1i = np.random.rand(), np.random.rand()
        w2r, w2i = np.random.rand(), np.random.rand()
        def adder_pt(x, y):
            return (w1r + 1j * w1i) * x + (w2r + 1j * w2i) * y
        def adder_jt(x, y):
            w1 = ComplexNumber(real=jt.array(w1r).reshape(1, 1), imag=jt.array(w1i).reshape(1, 1))
            w2 = ComplexNumber(real=jt.array(w2r).reshape(1, 1), imag=jt.array(w2i).reshape(1, 1))
            return w1 * x + w2 * y
        s1 = (5, 6)
        m1 = self.random_complex_matrix(s1)
        m2 = self.random_complex_matrix(s1)
        m3 = self.random_complex_matrix(s1)
        inputs = [(m1, m2), m3]
        self.check_op_with_torch(
            partial(jt.gradfunctional.vjp, func=adder_jt, create_graph=True),
            partial(torch.autograd.functional.vjp, func=adder_pt, create_graph=True),
            inputs,
            jittor_knames=['inputs', 'v'],
            torch_knames=['inputs', 'v'],
            check_grad=False,
        )

if __name__ == "__main__":
    unittest.main()

@@ -0,0 +1,112 @@
# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers: Shizhan Lu <578752274@qq.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np

from jittor.compile_extern import cusparse_ops


class TestSpmmCsrOp(unittest.TestCase):
    @unittest.skipIf(not jt.has_cuda, "No CUDA support, skipping test")
    @jt.flag_scope(use_cuda=1, lazy_execution=0)
    def test_spmm_csr_forward_float32_int32(self):
        x = jt.array([[3.0, 2.0, 1.0], [3.0, 2.0, 1.0], [3.0, 2.0, 1.0]], dtype="float32")
        col_indices = jt.array([0, 1, 1, 2], dtype="int32")
        row_offset = jt.array([0, 2, 3, 4], dtype="int32")
        csr_weight = jt.array([3.0, 1.0, 4.0, 2.0], dtype="float32")
        output = jt.zeros((3, 3), dtype="float32")
        cusparse_ops.cusparse_spmmcsr(
            output, x, col_indices, csr_weight, row_offset,
            3, 3, False, False
        ).fetch_sync()
        expected_output = np.array([
            [12.0, 8.0, 4.0],
            [12.0, 8.0, 4.0],
            [6.0, 4.0, 2.0]
        ])
        np.testing.assert_allclose(output.data, expected_output, atol=1e-5)

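    # A worked check of the CSR layout above (standard CSR semantics, verifiable
    # with plain numpy): row_offset [0, 2, 3, 4] and col_indices [0, 1, 1, 2] with
    # values [3, 1, 4, 2] densify to A = [[3, 1, 0], [0, 4, 0], [0, 0, 2]], and
    # A @ x reproduces expected_output:
    #
    #   A = np.zeros((3, 3), dtype=np.float32)
    #   ro, ci, vals = [0, 2, 3, 4], [0, 1, 1, 2], [3.0, 1.0, 4.0, 2.0]
    #   for r in range(3):
    #       for p in range(ro[r], ro[r + 1]):
    #           A[r, ci[p]] = vals[p]
    #   x_np = np.array([[3.0, 2.0, 1.0]] * 3, dtype=np.float32)
    #   assert np.allclose(A @ x_np, expected_output)
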
    @unittest.skipIf(not jt.has_cuda, "No CUDA support, skipping test")
    @jt.flag_scope(use_cuda=1, lazy_execution=0)
    def test_spmm_csr_forward_float16_int32(self):
        x = jt.array([[3.0, 2.0, 1.0], [3.0, 2.0, 1.0], [3.0, 2.0, 1.0]], dtype="float16")
        col_indices = jt.array([0, 1, 1, 2], dtype="int32")
        row_offset = jt.array([0, 2, 3, 4], dtype="int32")
        csr_weight = jt.array([3.0, 1.0, 4.0, 2.0], dtype="float16")
        output = jt.zeros((3, 3), dtype="float16")
        cusparse_ops.cusparse_spmmcsr(
            output, x, col_indices, csr_weight, row_offset,
            3, 3, False, False
        ).fetch_sync()
        expected_output = np.array([
            [12.0, 8.0, 4.0],
            [12.0, 8.0, 4.0],
            [6.0, 4.0, 2.0]
        ], dtype="float16")
        np.testing.assert_allclose(output.data, expected_output, atol=1e-5)

    # @unittest.skipIf(not jt.has_cuda, "No CUDA support, skipping test")
    # @jt.flag_scope(use_cuda=1, lazy_execution=0)
    # def test_spmm_csr_forward_float64_int32(self):
    #     x = jt.array([[3.0, 2.0, 1.0], [3.0, 2.0, 1.0], [3.0, 2.0, 1.0]], dtype="float64")
    #     col_indices = jt.array([0, 1, 1, 2], dtype="int32")
    #     row_offset = jt.array([0, 2, 3, 4], dtype="int32")
    #     csr_weight = jt.array([3.0, 1.0, 4.0, 2.0], dtype="float64")
    #     output = jt.zeros((3, 3), dtype="float64")
    #     cusparse_ops.cusparse_spmmcsr(
    #         output, x, col_indices, csr_weight, row_offset,
    #         3, 3, False, False
    #     ).fetch_sync()
    #     expected_output = np.array([
    #         [12.0, 8.0, 4.0],
    #         [12.0, 8.0, 4.0],
    #         [6.0, 4.0, 2.0]
    #     ], dtype="float64")
    #     np.testing.assert_allclose(output.data, expected_output, atol=1e-5)

    @unittest.skipIf(not jt.has_cuda, "No CUDA support, skipping test")
    @jt.flag_scope(use_cuda=1, lazy_execution=0)
    def test_spmm_csr_forward_float32_int64(self):
        x = jt.array([[3.0, 2.0, 1.0], [3.0, 2.0, 1.0], [3.0, 2.0, 1.0]], dtype="float32")
        col_indices = jt.array([0, 1, 1, 2], dtype="int64")
        row_offset = jt.array([0, 2, 3, 4], dtype="int64")
        csr_weight = jt.array([3.0, 1.0, 4.0, 2.0], dtype="float32")
        output = jt.zeros((3, 3), dtype="float32")
        cusparse_ops.cusparse_spmmcsr(
            output, x, col_indices, csr_weight, row_offset,
            3, 3, False, False
        ).fetch_sync()
        expected_output = np.array([
            [12.0, 8.0, 4.0],
            [12.0, 8.0, 4.0],
            [6.0, 4.0, 2.0]
        ], dtype="float32")
        np.testing.assert_allclose(output.data, expected_output, atol=1e-5)

    @unittest.skipIf(not jt.has_cuda, "No CUDA support, skipping test")
    @jt.flag_scope(use_cuda=1, lazy_execution=0)
    def test_spmm_coo(self):
        x = jt.array([[3.0, 2.0, 1.0], [4.0, 2.0, 2.0], [1.0, 2.0, 3.0]], dtype="float32")
        edge_index = jt.array([[0, 0, 1, 2], [1, 2, 2, 1]], dtype="int32")
        row_indices = edge_index[0, :]
        col_indices = edge_index[1, :]
        edge_weight = jt.array([1.0, 1.0, 1.0, 1.0], dtype="float32")
        feature_dim = jt.size(x, 1)
        output = jt.zeros((3, feature_dim))
        cusparse_ops.cusparse_spmmcoo(output, x, row_indices, col_indices, edge_weight, 3, 3, False, False).fetch_sync()
        print("Output:", output)
        expected_output = np.array([
            [5.0, 4.0, 5.0],
            [1.0, 2.0, 3.0],
            [4.0, 2.0, 2.0]
        ], dtype="float32")
        np.testing.assert_allclose(output.data, expected_output, atol=1e-5)

if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,63 @@
import unittest
import jittor as jt
import numpy as np


class TestFinfo(unittest.TestCase):
    def test(self):
        for dtype in ['float16', 'float32', 'float64']:
            finfo = jt.finfo(dtype)
            np_finfo = np.finfo(dtype)
            assert finfo.bits == np_finfo.bits
            assert finfo.eps == np_finfo.eps
            assert finfo.max == np_finfo.max
            assert finfo.min == np_finfo.min
            assert finfo.nexp == np_finfo.nexp
            assert finfo.nmant == np_finfo.nmant
            assert finfo.iexp == np_finfo.iexp
            assert finfo.precision == np_finfo.precision
            assert finfo.resolution == np_finfo.resolution
            assert finfo.tiny == np_finfo.tiny
        for dtype_jt, dtype in [
            (jt.float16, 'float16'),
            (jt.float32, 'float32'),
            (jt.float64, 'float64'),
        ]:
            finfo = jt.finfo(dtype_jt)
            np_finfo = np.finfo(dtype)
            assert finfo.bits == np_finfo.bits
            assert finfo.eps == np_finfo.eps
            assert finfo.max == np_finfo.max
            assert finfo.min == np_finfo.min
            assert finfo.nexp == np_finfo.nexp
            assert finfo.nmant == np_finfo.nmant
            assert finfo.iexp == np_finfo.iexp
            assert finfo.precision == np_finfo.precision
            assert finfo.resolution == np_finfo.resolution
            assert finfo.tiny == np_finfo.tiny


class TestIinfo(unittest.TestCase):
    def test(self):
        for dtype in ['int16', 'int32', 'int64']:
            iinfo = jt.iinfo(dtype)
            np_iinfo = np.iinfo(dtype)
            assert iinfo.bits == np_iinfo.bits
            assert iinfo.max == np_iinfo.max
            assert iinfo.min == np_iinfo.min
            assert iinfo.dtype == np.dtype(dtype)
        for dtype_jt, dtype in [
            (jt.int16, 'int16'),
            (jt.int32, 'int32'),
            (jt.int64, 'int64'),
        ]:
            iinfo = jt.iinfo(dtype_jt)
            np_iinfo = np.iinfo(dtype)
            assert iinfo.bits == np_iinfo.bits
            assert iinfo.max == np_iinfo.max
            assert iinfo.min == np_iinfo.min
            assert iinfo.dtype == np.dtype(dtype)


if __name__ == "__main__":
    unittest.main()
@@ -46,6 +46,33 @@ class TestVarFunctions(unittest.TestCase):
        np.testing.assert_allclose(jt_x.norm(2, 1).numpy(), tc_x.norm(2, 1).numpy(), atol=1e-6)
        np.testing.assert_allclose(jt_x.norm(2, 0).numpy(), tc_x.norm(2, 0).numpy(), atol=1e-6)

    def test_std_with_dim(self):
        x = np.random.randn(100, 1000).astype(np.float32)
        jt_x = jt.array(x)
        tc_x = torch.from_numpy(x)
        np.testing.assert_allclose(jt_x.std(dim=-1).numpy(), tc_x.std(dim=-1).numpy(), 1e-4)
        np.testing.assert_allclose(jt_x.std(dim=0, keepdim=True).numpy(), tc_x.std(dim=0, keepdim=True).numpy(), 1e-4)

    def test_diagonal(self):
        x = np.reshape(np.arange(5 * 6 * 7 * 8), (5, 6, 7, 8))
        jt_x = jt.array(x)
        tc_x = torch.from_numpy(x)
        def __assert_equal(a: np.ndarray, b: np.ndarray, rtol=1e-6, atol=1e-6):
            assert a.shape == b.shape, f"{a.shape}!={b.shape}"
            np.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
        __assert_equal(jt.misc.diagonal(jt_x, 0, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=0, dim1=1, dim2=2).numpy())
        __assert_equal(jt.misc.diagonal(jt_x, -1, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=-1, dim1=1, dim2=2).numpy())
        __assert_equal(jt.misc.diagonal(jt_x, -2, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=-2, dim1=1, dim2=2).numpy())
        __assert_equal(jt.misc.diagonal(jt_x, -6, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=-6, dim1=1, dim2=2).numpy())
        __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=1, dim1=1, dim2=2).numpy())
        __assert_equal(jt.misc.diagonal(jt_x, 2, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=2, dim1=1, dim2=2).numpy())
        __assert_equal(jt.misc.diagonal(jt_x, 7, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=7, dim1=1, dim2=2).numpy())
        __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=-1, dim2=2).numpy(), tc_x.diagonal(offset=1, dim1=-1, dim2=2).numpy())
        __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=-1, dim2=-2).numpy(), tc_x.diagonal(offset=1, dim1=-1, dim2=-2).numpy())
        __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=-2, dim2=-1).numpy(), tc_x.diagonal(offset=1, dim1=-2, dim2=-1).numpy())
        __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=0, dim2=-2).numpy(), tc_x.diagonal(offset=1, dim1=0, dim2=-2).numpy())
        __assert_equal(jt.misc.diagonal(jt_x, 1, dim1=2, dim2=1).numpy(), tc_x.diagonal(offset=1, dim1=2, dim2=1).numpy())


if __name__ == "__main__":
    unittest.main()