fix master

lidongyang 2025-07-28 18:33:35 +08:00
parent c78db2a794
commit 4017b161d2
49 changed files with 503 additions and 6921 deletions

View File

@@ -1,430 +0,0 @@
# import os
# os.environ["FIX_TORCH_ERROR"] = "0"
# import jittor as jt
# from jittor import *
# from typing import Tuple
# org_int = int = type(1)
# org_float = float = type(1.0)
# org_bool = bool = type(True)
# import jtorch.compiler
# import jtorch_core
# from jtorch_core import *
# device.__reduce__ = lambda self: (device, (self.type,))
# device.__module__ = "jtorch"
# jt.jittor_core.device = device
# def handle_dtype(args, kw, dtype):
# def convert(x):
# if isinstance(x, jt.Var):
# return x.cast(dtype)
# return x
# if dtype is not None:
# if args is not None:
# if isinstance(args, (tuple,list)):
# args = [ convert(a) for a in args ]
# else:
# args = convert(args)
# if kw is not None:
# kw = { k:convert(v) for k,v in kw.items() }
# return args, kw
# def get_args_names(func):
# import inspect
# spec = inspect.getfullargspec(func)
# return spec[0] + spec[4]
# def wrapper(func):
# has_dtype = False
# if hasattr(func, "__code__"):
# has_dtype = "dtype" in get_args_names(func)
# def inner(*args, **kw):
# requires_grad = None
# dtype = None
# if "requires_grad" in kw:
# requires_grad = kw["requires_grad"]
# del kw["requires_grad"]
# if not has_dtype and "dtype" in kw:
# dtype = kw["dtype"]
# del kw["dtype"]
# if "device" in kw:
# del kw["device"]
# if 'pin_memory' in kw:
# del kw['pin_memory']
# args, kw = handle_dtype(args, kw, dtype)
# ret = func(*args, **kw)
# if isinstance(ret, jt.Var):
# if requires_grad is not None:
# ret.requires_grad = requires_grad
# if dtype is not None:
# ret = ret.astype(dtype)
# return ret
# return inner
# import inspect
# _wrapper_keys = set(["shape", "start", "size"])
# _wrapper_keys.add("x")
# for k,v in list(globals().items()):
# if callable(v) and not isinstance(v, type):
# try:
# spec = inspect.getfullargspec(v)
# args_name = spec[0]
# if len(args_name) and args_name[0] in _wrapper_keys:
# globals()[k] = wrapper(v)
# elif spec.varargs in _wrapper_keys:
# globals()[k] = wrapper(v)
# except:
# pass
# def empty(*size, dtype=jt.float32, device=None, requires_grad=False):
# if len(size) == 1 and not isinstance(size[0], org_int):
# size = size[0]
# return jt.empty(size, dtype)
# Tensor = Var
# Tensor.backward = lambda x: jtorch_core.backward(x)
# Tensor.grad = property(grad_get, grad_set, grad_del)
# Tensor.retains_grad = property(retain_grad_get, retain_grad_set)
# def retain_grad(x:Tensor, value:bool=True):
# x.retains_grad = value
# return value
# Tensor.retain_grad = retain_grad
# Tensor.dim = lambda self: self.ndim
# Tensor.ndimension = lambda self: self.ndim
# Tensor.nelement = lambda self: self.numel()
# Tensor.cuda = lambda self: self
# def device_get(x:Tensor):
# return device("cpu") if not jt.has_cuda or not jt.flags.use_cuda else device("cuda")
# Tensor.device = property(device_get)
# def argmax(x: Var, dim=None, keepdim: bool = False):
# return jt.argmax(x, dim, keepdim)[0]
# Tensor.argmax = argmax
# def tensor_type(x: Var, dtype=None, **kwargs):
# if dtype:
# return x.astype(dtype)
# else:
# return x.dtype
# Tensor.type = tensor_type
# def is_floating_point(x: Var):
# return "float" in str(x.dtype)
# Tensor.is_floating_point = is_floating_point
# from . import autograd
# from .autograd import *
# def tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False):
# if isinstance(data,list):
# data_list = []
# check = True
# for p in data:
# if isinstance(p, Tensor) and p.numel()==1:
# data_list.append(p.item())
# elif isinstance(p, (org_int,org_float)):
# data_list.append(p)
# else:
# check = False
# break
# if check:
# data = data_list
# return wrapper(array)(data, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory)
# # tensor = wrapper(array)
# from_numpy = wrapper(array)
# strided = None
# def mod_zero_grad(self):
# for p in self.parameters():
# p.grad = None
# Module.zero_grad = mod_zero_grad
# class ModuleMisc:
# def parameters(self):
# return iter(super().parameters())
# def load_state_dict(self, state_dict, strict=False):
# return super().load_state_dict(state_dict)
# def to(self, device=None,dtype=None):
# ''' do nothing but return itself '''
# return self
# def register_parameter(self,name,data):
# setattr(self, name, data)
# def buffers(self):
# for _, buf in self.named_buffers():
# yield buf
# def make_module(cls):
# class TMod(ModuleMisc, cls):
# def __init__(self, *args, **kw):
# dtype = None
# if "dtype" in kw:
# dtype = kw["dtype"]
# del kw["dtype"]
# self._dtype = dtype
# with jt.flag_scope(th_mode=0):
# if "device" in kw:
# del kw["device"]
# super().__init__(*args, **kw)
# for k,v in self.__dict__.items():
# if not k.startswith("_") and isinstance(v, Var) \
# and v.requires_grad:
# v.retain_grad()
# if dtype is not None and isinstance(v, Var):
# v.assign(v.cast(dtype))
# def __call__(self, *args, **kw):
# args, kw = handle_dtype(args, kw, self._dtype)
# # if forward is override by user, call forward
# if self.__class__.forward is not TMod.forward:
# return self.forward(*args, **kw)
# return self.execute(*args, **kw)
# def forward(self, *args, **kw):
# args, kw = handle_dtype(args, kw, self._dtype)
# return self.execute(*args, **kw)
# @property
# def training(self):
# if not hasattr(self, "is_train"):
# self.is_train = True
# return self.is_train
# @training.setter
# def training(self, value):
# self.is_train = value
# TMod.__name__ = cls.__name__
# return TMod
# import jtorch.cuda
# import jtorch.nn
# from jtorch.nn import Module, Parameter
# import jtorch.optim
# from jtorch.utils.dtype import Dtype, get_string_dtype
# def frombuffer(buffer: bytearray,
# *,
# dtype: Dtype,
# count: int = -1,
# offset: int = 0,
# requires_grad: bool = True) -> Tensor:
# dtype = get_string_dtype(dtype)
# tensor = jt.array(np.frombuffer(buffer, dtype, count=count, offset=offset))
# if requires_grad and tensor.dtype.is_float():
# tensor.requires_grad = True
# return tensor
# def conflict_wrapper(origin_func, new_func):
# def wrapper(*args, **kw):
# if jt.flags.th_mode:
# return new_func(*args, **kw)
# else:
# return origin_func(*args, **kw)
# return wrapper
# def min(*args, **kw):
# dim = None
# if len(args) >= 2 and isinstance(args[1], org_int):
# dim = args[1]
# elif "dim" in kw and isinstance(kw["dim"], org_int):
# dim = kw["dim"]
# if dim is not None:
# k, v = jt.argmin(*args, **kw)
# return v, k
# elif len(args) == 2 and isinstance(args[1], jt.Var):
# return jt.minimum(args[0], args[1])
# else:
# return jt.min(*args, **kw)
# Tensor.min = conflict_wrapper(jt.min, min)
# def max(*args, **kw):
# dim = None
# if "dim" in kw:
# x = kw["dim"]
# if len(args) >= 2 and isinstance(args[1], org_int):
# dim = args[1]
# elif "dim" in kw and isinstance(kw["dim"], org_int):
# dim = kw["dim"]
# if dim is not None:
# k, v = jt.argmax(*args, **kw)
# return v, k
# elif len(args) == 2 and isinstance(args[1], jt.Var):
# return jt.maximum(args[0], args[1])
# else:
# return jt.max(*args, **kw)
# Tensor.max = conflict_wrapper(jt.max, max)
# def argsort(*args, **kw):
# k, v = jt.argsort(*args, **kw)
# return k
# Tensor.argsort = conflict_wrapper(jt.argsort, argsort)
# LongTensor = jt.int64
# FloatTensor = jt.float
# HalfTensor = jt.float16
# BoolTensor = jt.bool
# IntTensor = jt.int32
# class JDType:
# def __init__(self, func, str):
# self.func = func
# self.str = str
# self.__name__ = str.split(".")[-1]
# def __call__(self, *args, **kw):
# return self.func(*args, **kw)
# def __str__(self):
# return self.str
# def is_floating_point(self):
# return "float" in str(self.str)
# int8 = JDType(jt.int8, "torch.int8")
# int16 = JDType(jt.int16, "torch.int16")
# int = int32 = JDType(jt.int32, "torch.int32")
# long = int64 = JDType(jt.int64, "torch.int64")
# half = float16 = JDType(jt.float16, "torch.float16")
# float = float32 = JDType(jt.float32, "torch.float32")
# double = float64 = JDType(jt.float64, "torch.float64")
# bfloat16 = "bfloat16" # TODO
# complex64 = "complex64" # TODO
# complex128 = "complex128" # TODO
# def get_JDtype(dtype):
# if dtype=='float32' or dtype == jt.float32:
# return float32
# elif dtype=='float64' or dtype == jt.float64:
# return float64
# elif dtype=='float16' or dtype == jt.float16:
# return float16
# elif dtype=='int32' or dtype == jt.int32:
# return int32
# elif dtype=='int64' or dtype == jt.int64:
# return int64
# elif dtype=='int16' or dtype == jt.int16:
# return int16
# elif dtype=='int8' or dtype == jt.int8:
# return int8
# else:
# raise Exception("dtype {} not supported".format(dtype))
# def load(path,**kwargs):
# def _to_jittor(data):
# if isinstance(data,dict):
# return {k:_to_jittor(d) for k,d in data.items()}
# if isinstance(data,list):
# return [_to_jittor(d) for d in data]
# if isinstance(data,np.ndarray):
# return jt.array(data)
# return data
# data = jt.load(path)
# return _to_jittor(data)
# def is_tensor(x):
# return isinstance(x, Tensor)
# manual_seed = jt.set_global_seed
# jt.flags.amp_level = 3
# Size = jt.NanoVector
# class Generator:
# def __init__(self,*args,**kw) -> None:
# self.seed = None
# def manual_seed(self,seed):
# self.seed = seed
# from . import fx
# _default_type = "float32"
# def get_default_dtype():
# return _default_type
# def set_default_dtype(dtype):
# global _default_type
# _default_type = dtype
# dtype = JDType
# def div(x,y,rounding_mode="floor"):
# assert rounding_mode == "floor"
# z = (x / y)
# if rounding_mode == "floor":
# z = z.floor()
# if x.dtype == "int32" and (isinstance(y,org_int) or y.dtype == "int32"):
# z = z.int32()
# return z
# def randn(*args,**kw):
# wrap_randn = wrapper(jt.randn)
# generator = kw.get('generator',None)
# kw.pop('generator',None)
# if 'layout' in kw:
# del kw['layout']
# if generator is not None and generator.seed is not None:
# jt.set_global_seed(generator.seed)
# return wrap_randn(*args,**kw)
# def rand(*args,**kw):
# print("rand")
# wrap_rand = wrapper(jt.rand)
# generator = kw.get('generator',None)
# kw.pop('generator',None)
# if 'layout' in kw:
# del kw['layout']
# if generator is not None and generator.seed is not None:
# jt.set_global_seed(generator.seed)
# return wrap_rand(*args,**kw)
# def set_default_tensor_type(t):  # t may be a type or a dotted type-name string
# if isinstance(t, str):
# info = t.split(".")
# if len(info) == 3 and info[1] == 'cuda':
# jt.flags.use_cuda = 1
# #TODO: type
# def clamp(x, min=None, max=None):
# return jt.clamp(x, min, max)
# def to(x,*args,**kw):
# device = None
# if len(args) == 1:
# device = args[0]
# if isinstance(device, jt.NanoString) or callable(device):
# return jt.to(x,*args,**kw)
# if 'cpu' in str(device):
# args = []
# device = kw.get("device",None)
# if 'cpu' in str(device):
# kw.pop('device',None)
# print("to cpu")
# # print(kw)
# return jt.to(x,*args,**kw)
# Tensor.to = conflict_wrapper(jt.to, to)
# mm = wrapper(jt.matmul)
# def _data_get(x):
# return x
# def _data_set(x, value):
# x.assign(value)
# Tensor.data = property(_data_get, _data_set)
# Tensor.layout = None
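If this module were active (every line above is commented out in this revision), the wrapper layer is what lets torch-style keyword arguments flow into jittor ops. A minimal usage sketch under that assumption:

import jtorch

# dtype/requires_grad are intercepted by wrapper(); device is accepted and dropped
x = jtorch.tensor([1.0, 2.0, 3.0], dtype=jtorch.float32, requires_grad=True)
y = (x * x).sum()
y.backward()            # dispatches to jtorch_core.backward via Tensor.backward
print(x.grad)           # expected: 2*x, i.e. [2, 4, 6]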

View File

@@ -1,134 +0,0 @@
import jittor as jt
from jittor import Var
from collections.abc import Sequence, Mapping
Variable = Var
class FunctionContext:
def save_for_backward(self, *args):
self.saved_tensors = args
class Function:
''' Function Module for customized backward operations
Example 1 (a Function can have multiple inputs and multiple outputs, and the
user can store values for the backward computation)::
import jtorch
from jtorch import Function
class MyFunc(Function):
@staticmethod
def forward(self, x, y):
self.x = x
self.y = y
return x*y, x/y
@staticmethod
def backward(self, grad0, grad1):
return grad0 * self.y, grad1 * self.x
a = jtorch.array(3.0)
a.requires_grad = True
b = jtorch.array(4.0)
b.requires_grad = True
func = MyFunc.apply
c,d = func(a, b)
(c+d*3).backward()
assert a.grad.data == 4
assert b.grad.data == 9
Example 2 (a Function can return None for no gradient, and an incoming
gradient can also be None)::
import jtorch
from jtorch import Function
class MyFunc(Function):
@staticmethod
def forward(self, x, y):
self.x = x
self.y = y
return x*y, x/y
@staticmethod
def backward(self, grad0, grad1):
assert grad1 is None
return grad0 * self.y, None
a = jt.array(3.0)
a.requires_grad = True
b = jt.array(4.0)
b.requires_grad = True
func = MyFunc.apply
c,d = func(a, b)
d.stop_grad()
da, db = jt.grad(c+d*3, [a, b])
assert da.data == 4
assert db.data == 0
'''
def __call__(self, *args):
backup = args
args = list(args)
taped_inputs = []
taped_outputs = []
input_mask = [-1] * len(args)
for i,v in enumerate(args):
if isinstance(v, Var):
if v.is_stop_grad():
# -2 in input_mask represents it is stop_grad
input_mask[i] = -2
continue
v = v.tape()
input_mask[i] = len(taped_inputs)
args[i] = v
taped_inputs.append(v)
ctx = FunctionContext()
ori_res = self.forward(ctx, *args)
# ori_res = self.execute(*args)
if not isinstance(ori_res, Sequence):
res = [ori_res]
else:
res = list(ori_res)
output_mask = [-1] * len(res)
for i,v in enumerate(res):
if isinstance(v, Var):
v = v.tape()
output_mask[i] = len(taped_outputs)
res[i] = v
taped_outputs.append(v)
ctx.input_mask = input_mask
ctx.output_mask = output_mask
# tape output and input together so
# backward treat them as one operator
jt.tape_together(taped_inputs, taped_outputs,
lambda *args: self._grad(ctx, self, *args))
if isinstance(ori_res, Sequence):
return res
else:
return res[0]
@staticmethod
def _grad(ctx, func, *args):
new_args = ( (args[i] if i>=0 else None) for i in ctx.output_mask )
ret = func.backward(ctx, *new_args)
if not isinstance(ret, Sequence):
ret = (ret,)
new_ret = []
for i, r in enumerate(ret):
j = ctx.input_mask[i]
if j<0:
# -2 in input_mask represents it is stop_grad
assert r is None or j==-2, f"{type(self)}'s {i}-th returned grad should be None, "\
"because the input value is not jittor variable."
else:
new_ret.append(r)
return new_ret
def dfs(self, parents, k, callback, callback_leave=None):
pass
@classmethod
def apply(cls, *args, **kw):
func = cls()
return func(*args, **kw)

View File

@@ -1,39 +0,0 @@
import jittor as jt
import jittor_utils
import glob
import os
from jittor import pyjt_compiler
import sys
from jittor_utils import lock
jtorch_path = os.path.dirname(__file__)
cache_path = os.path.join(jt.compiler.cache_path, "jtorch")
# os.makedirs(cache_path, exist_ok=True)
os.makedirs(os.path.join(cache_path, "gen"), exist_ok=True)
with lock.lock_scope():
pyjt_gen_src = pyjt_compiler.compile(cache_path, jtorch_path)
ext_args = 'c[cu]' if jt.has_cuda else 'cc'
files = glob.glob(jtorch_path+"/src/**/*."+ext_args, recursive=True)
files += pyjt_gen_src
cc_flags = " -I\""+os.path.join(jtorch_path, "src")+"\" "
if os.environ.get("use_data_o", "1") == "1":
files += glob.glob(jtorch_path+"/src/**/*.o", recursive=True)
files = [f for f in files if "__data__" not in f]
with lock.lock_scope():
jt.compiler.compile(
jt.compiler.cc_path,
jt.compiler.cc_flags+jt.compiler.opt_flags+ cc_flags,
files,
"jtorch_core"+jt.compiler.extension_suffix,
obj_dirname="jtorch_objs")
with jittor_utils.import_scope(jt.compiler.import_flags):
import jtorch_core as core
jt.flags.th_mode = 1

View File

@@ -1,64 +0,0 @@
import jittor as jt
import jtorch
def is_available():
return jt.has_cuda
def device_count():
return int(jt.has_cuda)
def set_device(device=None):
pass
def get_rng_state(device=None):
pass
def current_device():
return jtorch.device("cuda")
def mem_get_info(i):
return ("75GB",)
class Generator:
def __init__(self):
pass
def set_state(self, state):
self.state = state
default_generators = [Generator()]
_lazy_call = lambda func: func()
device = None
LongTensor = jt.int64
FloatTensor = jt.float
HalfTensor = jt.float16
BoolTensor = jt.bool
manual_seed = jt.set_global_seed
manual_seed_all = jt.set_global_seed
def synchronize():
jt.sync_all(True)
class Event:
pass
class Stream:
pass
from typing import Any
from .gradscaler import GradScaler
class autocast:
def __init__(self,**kwargs):
pass
def __enter__(self,):
pass
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
pass
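These are deliberately thin shims, so torch.cuda-style call sites keep working. A small sketch of the intended behavior (hypothetical usage, assuming this module is importable as jtorch.cuda):

import jtorch.cuda as cuda

if cuda.is_available():          # jt.has_cuda
    cuda.manual_seed(0)          # alias of jt.set_global_seed
    with cuda.autocast():        # accepted but a no-op in this port
        pass
    cuda.synchronize()           # jt.sync_all(True)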

View File

@@ -1,53 +0,0 @@
import datetime
from enum import Enum
import jittor as jt
def is_initialized():
return True
def get_rank(group=None):
return 0
def get_world_size(group=None):
return 1
def get_backend(group=None):
return "nccl"
def new_group(ranks=None, timeout=datetime.timedelta(seconds=1800), backend=None, pg_options=None):
return 1
def barrier():
pass
def is_available():
return True
def is_built():
return True
class ReduceOp:
SUM = 0
class GroupMember:
WORLD = 0
class ProcessGroup:
pass
class Join:
pass
dist_backend = Enum("dist_backend", ("GLOO", "MPI", "NCCL"))
_backend = dist_backend.NCCL
def is_mpi_available():
return jt.in_mpi
def DistributedDataParallel(model, *args, **kw):
return model
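Everything here is stubbed for the single-process case, so code written against torch.distributed can run unmodified. A sketch of what the stubs guarantee:

import jtorch.distributed as dist

assert dist.is_available() and dist.is_initialized()
assert dist.get_rank() == 0 and dist.get_world_size() == 1

model = {"dummy": 1}
# DistributedDataParallel is a passthrough: the model comes back untouched
assert dist.DistributedDataParallel(model) is model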

View File

@@ -1,15 +0,0 @@
import jittor as jt
class RelaxedBernoulli:
def __init__(self, temperature, probs=None, logits=None):
self.temperature = temperature
self.probs = probs
# rsample operates on logits; derive them when only probs is given
if logits is None and probs is not None:
logits = jt.log(probs) - jt.log(1 - probs)
self.logits = logits
def rsample(self):
noise = jt.rand_like(self.logits)
eps = 1e-20
noise = jt.clamp(noise, eps, 1.0 - eps)
logit_noise = jt.log(noise) - jt.log(1 - noise)
sample = (self.logits + logit_noise) / self.temperature
return jt.sigmoid(sample)
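rsample above is the standard logistic (Gumbel-sigmoid) reparameterization of the binary Concrete distribution. A short usage sketch, assuming logits are passed directly:

import jittor as jt

rb = RelaxedBernoulli(temperature=0.5, logits=jt.array([0.0, 2.0, -2.0]))
sample = rb.rsample()        # differentiable, values strictly inside (0, 1)
print(sample.numpy())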

View File

@@ -1,5 +0,0 @@
#TODO: Implement FFT and IFFT
fftn = None
fftshift = None
ifftn = None
ifftshift = None

View File

@@ -1,2 +0,0 @@
class Proxy:
pass

View File

@@ -1,519 +0,0 @@
from collections import defaultdict, abc
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, cast
import inspect
import warnings
import jittor as jt
# import torch
def _refresh_per_optimizer_state():
return {}
class GradScaler:
_scale: Optional[jt.Var]
_growth_tracker: Optional[jt.Var]
_per_optimizer_states: Dict[int, Dict[str, Any]]
"""
An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
conveniently.
* ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
* ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
* ``scaler.update()`` updates ``scaler``'s scale factor.
Example::
# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
scaler.scale(loss).backward()
# scaler.step() first unscales gradients of the optimizer's params.
# If gradients don't contain infs/NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
(along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty,
and multiple losses/optimizers.
``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
without incurring inf or NaN gradient values.
``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
* If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
* If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
``growth_factor``.
The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).
Args:
init_scale (float, optional, default=2.**16): Initial scale factor.
growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
:meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
:meth:`update` if inf/NaN gradients occur in an iteration.
growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
that must occur for the scale to be multiplied by ``growth_factor``.
enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
invokes the underlying ``optimizer.step()``, and other methods become no-ops.
Default: ``True``
"""
def __init__(self,
init_scale=2.**16,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
enabled=True):
self._enabled = enabled
if self._enabled:
assert growth_factor > 1.0, "The growth factor must be > 1.0."
assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
self._init_scale = init_scale
# self._scale will be lazily initialized during the first call to scale()
self._scale = None
self._growth_factor = growth_factor
self._backoff_factor = backoff_factor
self._growth_interval = growth_interval
self._init_growth_tracker = 0
# self._growth_tracker will be lazily initialized during the first call to scale()
self._growth_tracker = None
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def _check_scale_growth_tracker(self, funcname) -> Tuple[jt.Var, jt.Var]:
fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
return (self._scale, self._growth_tracker)
def _lazy_init_scale_growth_tracker(self):
assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
self._scale = self._init_scale
self._growth_tracker = self._init_growth_tracker
def scale(self, outputs):
"""
Multiplies ('scales') a tensor or list of tensors by the scale factor.
Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
unmodified.
Args:
outputs (Tensor or iterable of Tensors): Outputs to scale.
"""
if not self._enabled:
return outputs
# Short-circuit for the common case.
if isinstance(outputs, jt.Var):
assert jt.flags.use_cuda == 1
if self._scale is None:
self._lazy_init_scale_growth_tracker()
assert self._scale is not None
return outputs * self._scale
def apply_scale(val):
if isinstance(val, jt.Var):
assert jt.flags.use_cuda == 1
if self._scale is None:
self._lazy_init_scale_growth_tracker()
assert self._scale is not None
return val * self._scale
elif isinstance(val, abc.Iterable):
iterable = map(apply_scale, val)
if isinstance(val, (list, tuple)):
return type(val)(iterable)
else:
return iterable
else:
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
return apply_scale(outputs)
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
with jt.no_grad():
optimizer.pre_step()
for group in optimizer.param_groups:
for to_unscale in group["grads"]:
if to_unscale is None or isinstance(to_unscale,(int,float)):
continue
if (not allow_fp16) and str(to_unscale.dtype) == "float16":
raise ValueError("Attempting to unscale FP16 gradients.")
if not (to_unscale.isinf().any()):
if inv_scale != 1.0:
to_unscale.update(to_unscale*inv_scale)
else:
found_inf = 1.0
return found_inf
def unscale_(self, optimizer):
"""
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
:meth:`unscale_` is optional, serving cases where you need to
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
between the backward pass(es) and :meth:`step`.
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
...
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
Args:
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
.. note::
:meth:`unscale_` does not incur a CPU-GPU sync.
.. warning::
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
and only after all gradients for that optimizer's assigned parameters have been accumulated.
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
.. warning::
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
"""
if not self._enabled:
return
self._check_scale_growth_tracker("unscale_")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if hasattr(optimizer,"get_find_inf"):
return
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = 1.0 / self._scale
found_inf = 0.0
optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
def step(self, optimizer, *args, **kwargs):
"""
:meth:`step` carries out the following two operations:
1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.
``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
Returns the return value of ``optimizer.step(*args, **kwargs)``.
Args:
optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
args: Any arguments.
kwargs: Any keyword arguments.
.. warning::
Closure use is not currently supported.
"""
if (not self._enabled):
return optimizer.step(*args, **kwargs)
if "closure" in kwargs:
raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")
self._check_scale_growth_tracker("step")
optimizer_state = self._per_optimizer_states[id(optimizer)]
retval = None
if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling):
# This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
# The contract with custom optimizers is that their step() should accept an additional,
# optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
# it can query its own state, invoke unscale_ on itself, etc
# The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument
# to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale`
# and `found_inf` to the passed optimizer so that the optimizer can utilize those
# to skip the parameter updates or unscale gradients before updating parameters in
# the fused kernel, e.g. `FusedAdamMathFunctor`.
# In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`,
# while the method is expected to be called by users side, i.e. their optimizers.
kwargs_ = kwargs
has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters
if has_grad_scaler_kwarg:
warnings.warn(
"GradScaler is going to stop passing itself as a keyword argument to the passed "
"optimizer. In the near future GradScaler registers `grad_scale: Tensor` and "
"`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.",
FutureWarning)
kwargs_.update({"grad_scaler": self})
else:
if optimizer_state["stage"] is OptState.READY:
self._check_inf_per_device(optimizer)
scaler = self._get_scale_async()
found_inf = cast(
jt.Var,
sum([
t for t in optimizer_state["found_inf_per_device"].values()
])
)
optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler
optimizer.found_inf = found_inf
retval = optimizer.step(*args, **kwargs_)
optimizer_state["stage"] = OptState.STEPPED
if not has_grad_scaler_kwarg:
del optimizer.grad_scale
del optimizer.found_inf
return retval
if hasattr(optimizer,"get_find_inf"):
optimizer.set_grad_scale(self._scale)
optimizer.step()
optimizer_state["found_inf_per_device"] = optimizer.get_find_inf()
return
retval = None
if not optimizer_state["found_inf_per_device"]:
retval = optimizer.step(*args, **kwargs)
else:
optimizer.post_step()
return retval
def update(self, new_scale=None):
"""
Updates the scale factor.
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
the scale is multiplied by ``growth_factor`` to increase it.
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
used directly, it's used to fill GradScaler's internal scale tensor. So if
``new_scale`` was a tensor, later in-place changes to that tensor will not further
affect the scale GradScaler uses internally.)
Args:
new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.
.. warning::
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
been invoked for all optimizers used this iteration.
"""
if not self._enabled:
return
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
if new_scale is not None:
# Accept a new user-defined scale. _scale is kept as a Python float in this
# port, so store the value directly instead of filling a tensor in place.
if isinstance(new_scale, float):
self._scale = new_scale
else:
reason = "new_scale should be a float or a 1-element jt.Var with requires_grad=False."
assert isinstance(new_scale, jt.Var), reason
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale = new_scale.item()
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [state["found_inf_per_device"]
for state in self._per_optimizer_states.values()
]
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0]
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf_combined += found_infs[i]
current_scale = _scale
if found_inf_combined:
current_scale *=self._backoff_factor
_growth_tracker = 0
else:
successful = _growth_tracker+1
if successful == self._growth_interval:
new_scale = current_scale*self._growth_factor
if new_scale < 1e9:
current_scale = new_scale
_growth_tracker = 0
else:
_growth_tracker = successful
self._scale, self._growth_tracker = current_scale,_growth_tracker
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def _get_scale_async(self):
return self._scale
def get_scale(self):
"""
Returns a Python float containing the current scale, or 1.0 if scaling is disabled.
.. warning::
:meth:`get_scale` incurs a CPU-GPU sync.
"""
if self._enabled:
return self._init_scale if self._scale is None else self._get_scale_async()
else:
return 1.0
def get_growth_factor(self):
r"""
Returns a Python float containing the scale growth factor.
"""
return self._growth_factor
def set_growth_factor(self, new_factor):
r"""
Args:
new_scale (float): Value to use as the new scale growth factor.
"""
self._growth_factor = new_factor
def get_backoff_factor(self):
r"""
Returns a Python float containing the scale backoff factor.
"""
return self._backoff_factor
def set_backoff_factor(self, new_factor):
r"""
Args:
new_scale (float): Value to use as the new scale backoff factor.
"""
self._backoff_factor = new_factor
def get_growth_interval(self):
r"""
Returns a Python int containing the growth interval.
"""
return self._growth_interval
def set_growth_interval(self, new_interval):
r"""
Args:
new_interval (int): Value to use as the new growth interval.
"""
self._growth_interval = new_interval
def _get_growth_tracker(self):
if self._enabled:
return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker  # kept as a Python int in this port
else:
return 0
def is_enabled(self):
r"""
Returns a bool indicating whether this instance is enabled.
"""
return self._enabled
def state_dict(self):
r"""
Returns the state of the scaler as a :class:`dict`. It contains five entries:
* ``"scale"`` - a Python float containing the current scale
* ``"growth_factor"`` - a Python float containing the current growth factor
* ``"backoff_factor"`` - a Python float containing the current backoff factor
* ``"growth_interval"`` - a Python int containing the current growth interval
* ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.
If this instance is not enabled, returns an empty dict.
.. note::
If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
should be called after :meth:`update`.
"""
return {"scale": self.get_scale(),
"growth_factor": self._growth_factor,
"backoff_factor": self._backoff_factor,
"growth_interval": self._growth_interval,
"_growth_tracker": self._get_growth_tracker()} if self._enabled else {}
def load_state_dict(self, state_dict):
r"""
Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.
Args:
state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
"""
if not self._enabled:
return
if len(state_dict) == 0:
raise RuntimeError("The source state dict is empty, possibly because it was saved "
"from a disabled instance of GradScaler.")
self._init_scale = state_dict["scale"]
if self._scale is not None:
self._scale = state_dict["scale"]  # _scale is a Python float in this port
self._growth_factor = state_dict["growth_factor"]
self._backoff_factor = state_dict["backoff_factor"]
self._growth_interval = state_dict["growth_interval"]
self._init_growth_tracker = state_dict["_growth_tracker"]
if self._growth_tracker is not None:
self._growth_tracker = state_dict["_growth_tracker"]
def __getstate__(self):
state = self.__dict__.copy()
if self._enabled:
assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\
"of an iteration, or at the end after scaler.update()."
# Pickling _scale and _growth_tracker Tensors directly triggers
# "warnings.warn("pickle support for Storage will be removed in 1.5..."
# so instead, we set the unpickled instance up to reinitialize them lazily.
state['_init_scale'] = self.get_scale()
state['_init_growth_tracker'] = self._get_growth_tracker()
state['_scale'] = None
state['_growth_tracker'] = None
return state
def __setstate__(self, state):
self.__dict__.update(state)
def _check_inf_per_device(self, optimizer):
_scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
dummy_inv_scale = 1.0
found_inf = 0.0
self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
def _found_inf_per_device(self, optimizer):
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
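The class docstring already shows the basic loop; the other common pattern is explicit unscaling before gradient clipping. A hedged sketch against the names in this port (model, optimizer, loss_fn, and data are assumed to exist, and scale() requires jt.flags.use_cuda == 1):

scaler = GradScaler()
for input, target in data:
    optimizer.zero_grad()
    loss = loss_fn(model(input), target)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)   # grads are unscaled in place exactly once
    # gradients can be clipped or inspected here
    scaler.step(optimizer)       # skips optimizer.step() if inf/NaN was found
    scaler.update()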

View File

@@ -1,556 +0,0 @@
from collections import defaultdict, abc
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, cast
import inspect
import warnings
import jittor as jt
# import torch
__all__ = ["OptState", "GradScaler"]
# Defines default_factory for GradScaler's _per_optimizer_states defaultdict,
# as well as associated "enum" values. Prefers defining these at top level because
# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory.
# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler
# causes a circular reference, which we'd rather avoid.
class OptState(Enum):
READY = 0
UNSCALED = 1
STEPPED = 2
def _refresh_per_optimizer_state():
return {"stage": OptState.READY, "found_inf_per_device": {}}
class GradScaler:
_scale: Optional[jt.Var]
_growth_tracker: Optional[jt.Var]
_per_optimizer_states: Dict[int, Dict[str, Any]]
"""
An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
conveniently.
* ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
* ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
* ``scaler.update()`` updates ``scaler``'s scale factor.
Example::
# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
scaler.scale(loss).backward()
# scaler.step() first unscales gradients of the optimizer's params.
# If gradients don't contain infs/NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
(along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty,
and multiple losses/optimizers.
``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
without incurring inf or NaN gradient values.
``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
* If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
* If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
``growth_factor``.
The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).
Args:
init_scale (float, optional, default=2.**16): Initial scale factor.
growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
:meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
:meth:`update` if inf/NaN gradients occur in an iteration.
growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
that must occur for the scale to be multiplied by ``growth_factor``.
enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
invokes the underlying ``optimizer.step()``, and other methods become no-ops.
Default: ``True``
"""
def __init__(self,
init_scale=2.**16,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
enabled=True):
self._enabled = enabled
if self._enabled:
assert growth_factor > 1.0, "The growth factor must be > 1.0."
assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
self._init_scale = init_scale
# self._scale will be lazily initialized during the first call to scale()
self._scale = None
self._growth_factor = growth_factor
self._backoff_factor = backoff_factor
self._growth_interval = growth_interval
self._init_growth_tracker = 0
# self._growth_tracker will be lazily initialized during the first call to scale()
self._growth_tracker = None
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def _check_scale_growth_tracker(self, funcname) -> Tuple[jt.Var, jt.Var]:
fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
return (self._scale, self._growth_tracker)
def _lazy_init_scale_growth_tracker(self):
assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
self._scale = self._init_scale
self._growth_tracker = self._init_growth_tracker
def scale(self, outputs):
"""
Multiplies ('scales') a tensor or list of tensors by the scale factor.
Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
unmodified.
Args:
outputs (Tensor or iterable of Tensors): Outputs to scale.
"""
print("scale")
if not self._enabled:
return outputs
# Short-circuit for the common case.
if isinstance(outputs, jt.Var):
assert jt.flags.use_cuda == 1
if self._scale is None:
self._lazy_init_scale_growth_tracker()
assert self._scale is not None
return outputs * self._scale
def apply_scale(val):
if isinstance(val, jt.Var):
assert jt.flags.use_cuda == 1
if self._scale is None:
self._lazy_init_scale_growth_tracker()
assert self._scale is not None
return val * self._scale
elif isinstance(val, abc.Iterable):
iterable = map(apply_scale, val)
if isinstance(val, (list, tuple)):
return type(val)(iterable)
else:
return iterable
else:
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
return apply_scale(outputs)
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
# There could be hundreds of grads, so we'd like to iterate through them just once.
# However, we don't know their devices or dtypes in advance.
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdicts type annotations.
with jt.no_grad():
optimizer.pre_step()
for group in optimizer.param_groups:
for to_unscale in group["grads"]:
if to_unscale is None or isinstance(to_unscale,(int,float)):
continue
if (not allow_fp16) and str(to_unscale.dtype) == "float16":
raise ValueError("Attempting to unscale FP16 gradients.")
if not (to_unscale.isinf().any()):
if inv_scale != 1.0:
to_unscale.update(to_unscale*inv_scale)
else:
found_inf = 1.0
return found_inf
def unscale_(self, optimizer):
"""
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
:meth:`unscale_` is optional, serving cases where you need to
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
between the backward pass(es) and :meth:`step`.
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
...
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
Args:
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
.. note::
:meth:`unscale_` does not incur a CPU-GPU sync.
.. warning::
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
and only after all gradients for that optimizer's assigned parameters have been accumulated.
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
.. warning::
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
"""
if not self._enabled:
return
self._check_scale_growth_tracker("unscale_")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.UNSCALED:
raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")
elif optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = 1.0 / self._scale
found_inf = 0.0
optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
optimizer_state["stage"] = OptState.UNSCALED
def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
retval = None
if not optimizer_state["found_inf_per_device"]:
retval = optimizer.step(*args, **kwargs)
else:
optimizer.post_step()
return retval
def step(self, optimizer, *args, **kwargs):
"""
:meth:`step` carries out the following two operations:
1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.
``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
Returns the return value of ``optimizer.step(*args, **kwargs)``.
Args:
optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
args: Any arguments.
kwargs: Any keyword arguments.
.. warning::
Closure use is not currently supported.
"""
if (not self._enabled):
return optimizer.step(*args, **kwargs)
if "closure" in kwargs:
raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")
self._check_scale_growth_tracker("step")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("step() has already been called since the last update().")
retval = None
if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling):
# This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
# The contract with custom optimizers is that their step() should accept an additional,
# optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
# it can query its own state, invoke unscale_ on itself, etc
# The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument
# to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale`
# and `found_inf` to the passed optimizer so that the optimizer can utilize those
# to skip the parameter updates or unscale gradients before updating parameters in
# the fused kernel, e.g. `FusedAdamMathFunctor`.
# In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`,
# while the method is expected to be called by users side, i.e. their optimizers.
kwargs_ = kwargs
has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters
if has_grad_scaler_kwarg:
warnings.warn(
"GradScaler is going to stop passing itself as a keyword argument to the passed "
"optimizer. In the near future GradScaler registers `grad_scale: Tensor` and "
"`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.",
FutureWarning)
kwargs_.update({"grad_scaler": self})
else:
if optimizer_state["stage"] is OptState.READY:
self._check_inf_per_device(optimizer)
scaler = self._get_scale_async()
found_inf = cast(
jt.Var,
sum([
t for t in optimizer_state["found_inf_per_device"].values()
])
)
optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler
optimizer.found_inf = found_inf
retval = optimizer.step(*args, **kwargs_)
optimizer_state["stage"] = OptState.STEPPED
if not has_grad_scaler_kwarg:
del optimizer.grad_scale
del optimizer.found_inf
return retval
if optimizer_state["stage"] is OptState.READY:
self.unscale_(optimizer)
assert "found_inf_per_device" in optimizer_state, "No inf checks were recorded for this optimizer."
retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)
optimizer_state["stage"] = OptState.STEPPED
return retval
def update(self, new_scale=None):
"""
Updates the scale factor.
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
the scale is multiplied by ``growth_factor`` to increase it.
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
used directly, it's used to fill GradScaler's internal scale tensor. So if
``new_scale`` was a tensor, later in-place changes to that tensor will not further
affect the scale GradScaler uses internally.)
Args:
new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.
.. warning::
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
been invoked for all optimizers used this iteration.
"""
if not self._enabled:
return
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
if new_scale is not None:
# Accept a new user-defined scale. _scale is kept as a Python float in this
# port, so store the value directly instead of filling a tensor in place.
if isinstance(new_scale, float):
self._scale = new_scale
else:
reason = "new_scale should be a float or a 1-element jt.Var with requires_grad=False."
assert isinstance(new_scale, jt.Var), reason
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale = new_scale.item()
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [state["found_inf_per_device"]
for state in self._per_optimizer_states.values()
]
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0]
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf_combined += found_infs[i]
current_scale = _scale
if found_inf_combined:
current_scale *=self._backoff_factor
_growth_tracker = 0
else:
successful = _growth_tracker+1
if successful == self._growth_interval:
new_scale = current_scale*self._growth_factor
if new_scale < 1e9:
current_scale = new_scale
_growth_tracker = 0
else:
_growth_tracker = successful
self._scale, self._growth_tracker = current_scale,_growth_tracker
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def _get_scale_async(self):
return self._scale
def get_scale(self):
"""
Returns a Python float containing the current scale, or 1.0 if scaling is disabled.
.. warning::
:meth:`get_scale` incurs a CPU-GPU sync.
"""
if self._enabled:
return self._init_scale if self._scale is None else self._get_scale_async()
else:
return 1.0
def get_growth_factor(self):
r"""
Returns a Python float containing the scale growth factor.
"""
return self._growth_factor
def set_growth_factor(self, new_factor):
r"""
Args:
new_scale (float): Value to use as the new scale growth factor.
"""
self._growth_factor = new_factor
def get_backoff_factor(self):
r"""
Returns a Python float containing the scale backoff factor.
"""
return self._backoff_factor
def set_backoff_factor(self, new_factor):
r"""
Args:
new_scale (float): Value to use as the new scale backoff factor.
"""
self._backoff_factor = new_factor
def get_growth_interval(self):
r"""
Returns a Python int containing the growth interval.
"""
return self._growth_interval
def set_growth_interval(self, new_interval):
r"""
Args:
new_interval (int): Value to use as the new growth interval.
"""
self._growth_interval = new_interval
def _get_growth_tracker(self):
if self._enabled:
return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker  # kept as a Python int in this port
else:
return 0
def is_enabled(self):
r"""
Returns a bool indicating whether this instance is enabled.
"""
return self._enabled
def state_dict(self):
r"""
Returns the state of the scaler as a :class:`dict`. It contains five entries:
* ``"scale"`` - a Python float containing the current scale
* ``"growth_factor"`` - a Python float containing the current growth factor
* ``"backoff_factor"`` - a Python float containing the current backoff factor
* ``"growth_interval"`` - a Python int containing the current growth interval
* ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.
If this instance is not enabled, returns an empty dict.
.. note::
If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
should be called after :meth:`update`.
"""
return {"scale": self.get_scale(),
"growth_factor": self._growth_factor,
"backoff_factor": self._backoff_factor,
"growth_interval": self._growth_interval,
"_growth_tracker": self._get_growth_tracker()} if self._enabled else {}
def load_state_dict(self, state_dict):
r"""
Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.
Args:
state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
"""
if not self._enabled:
return
if len(state_dict) == 0:
raise RuntimeError("The source state dict is empty, possibly because it was saved "
"from a disabled instance of GradScaler.")
self._init_scale = state_dict["scale"]
if self._scale is not None:
self._scale = state_dict["scale"]  # _scale is a Python float in this port
self._growth_factor = state_dict["growth_factor"]
self._backoff_factor = state_dict["backoff_factor"]
self._growth_interval = state_dict["growth_interval"]
self._init_growth_tracker = state_dict["_growth_tracker"]
if self._growth_tracker is not None:
self._growth_tracker = state_dict["_growth_tracker"]
def __getstate__(self):
state = self.__dict__.copy()
if self._enabled:
assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\
"of an iteration, or at the end after scaler.update()."
# Pickling _scale and _growth_tracker Tensors directly triggers
# "warnings.warn("pickle support for Storage will be removed in 1.5..."
# so instead, we set the unpickled instance up to reinitialize them lazily.
state['_init_scale'] = self.get_scale()
state['_init_growth_tracker'] = self._get_growth_tracker()
state['_scale'] = None
state['_growth_tracker'] = None
return state
def __setstate__(self, state):
self.__dict__.update(state)
def _check_inf_per_device(self, optimizer):
_scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
dummy_inv_scale = 1.0
found_inf = 0.0
self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
def _found_inf_per_device(self, optimizer):
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

View File

@@ -1,12 +0,0 @@
import math
def _jit_set_profiling_mode(x): pass
def _jit_set_profiling_executor(x): pass
def _jit_override_can_fuse_on_cpu(x): pass
def _jit_override_can_fuse_on_gpu(x): pass
def script(func):
return func
inf = math.inf
nan = math.nan
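These hooks exist so TorchScript-flavored code imports cleanly; script in particular is an identity decorator. A sketch:

@script                  # returns the function unchanged in this port
def add(a, b):
    return a + b

assert add(1, 2) == 3    # runs as plain Python; nothing is compiled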

View File

@@ -1,281 +0,0 @@
import jtorch
from typing import List, Optional, Tuple, Iterable, Iterator, Mapping, Any, overload, TypeVar, Dict
from typing_extensions import Self
import jittor as jt
from jtorch import make_module, Tensor, ModuleMisc, wrapper
#from . import init
from jittor import Function
import operator
import warnings
for k,v in jt.nn.__dict__.items():
if callable(v):
globals()[k] = wrapper(v)
for k,v in jt.nn.__dict__.items():
if isinstance(v, type) and issubclass(v, jt.Module):
globals()[k] = make_module(v)
from collections import OrderedDict
from collections import abc as container_abcs
class Module(ModuleMisc, jt.Module):
def __call__(self, *args, **kw):
return self.execute(*args, **kw)
def execute(self, *args, **kw):
return self.forward(*args, **kw)
def get_submodule(self, target: str):
if target == "":
return self
atoms: List[str] = target.split(".")
mod: jt.nn.Module = self
for item in atoms:
if not hasattr(mod, item):
raise AttributeError(mod._get_name() + " has no "
"attribute `" + item + "`")
mod = getattr(mod, item)
if not isinstance(mod, jt.nn.Module):
raise AttributeError("`" + item + "` is not "
"an nn.Module")
return mod
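# Illustrative usage (hypothetical module tree): with net.block.fc defined,
# net.get_submodule("block.fc") returns the fc submodule and
# net.get_submodule("") returns net itself; a missing name raises
# AttributeError.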
def Parameter(x:Tensor, requires_grad:bool=True) -> Tensor:
x = x.clone()
x.requires_grad = requires_grad
x.retains_grad = requires_grad
return x
def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False):
# padding_idx, max_norm, norm_type, scale_grad_by_freq and sparse are accepted
# for torch API compatibility but are ignored by jt.nn.embedding.
return jt.nn.embedding(input, weight)
def dropout(x, p=0.5, training=False):
return jt.nn.dropout(x, p, training)
class Flatten(Module):
''' Flattens the contiguous range of dimensions in a Var.
:param start_dim: the first dimension to be flattened. Default: 1.
:type start_dim: int
:param end_dim: the last dimension to be flattened. Default: -1.
:type end_dim: int
'''
def __init__(self, start_dim=1, end_dim=-1):
self.start_dim = start_dim
self.end_dim = end_dim
def forward(self, x) -> jt.Var:
return x.flatten(self.start_dim, self.end_dim)
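# Illustrative semantics: with the defaults (start_dim=1, end_dim=-1), an
# input of shape (N, C, H, W) is flattened to (N, C*H*W), matching
# torch.nn.Flatten.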
class _IncompatibleKeys:
def __init__(self, missing_keys, unexpected_keys):
self.missing_keys = missing_keys
self.unexpected_keys = unexpected_keys
_BatchNorm = None
#from . import utils
normalize = wrapper(jt.normalize)
T = TypeVar('T', bound=Module)
class ModuleDict(Module):
_modules: Dict[str, Module] # type: ignore[assignment]
def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
super().__init__()
if modules is not None:
self.update(modules)
def __getitem__(self, key: str) -> Module:
return self._modules[key]
def __setitem__(self, key: str, module: Module) -> None:
self.add_module(key, module)
def __delitem__(self, key: str) -> None:
del self._modules[key]
def __len__(self) -> int:
return len(self._modules)
def __iter__(self) -> Iterator[str]:
return iter(self._modules)
def __contains__(self, key: str) -> bool:
return key in self._modules
def clear(self) -> None:
"""Remove all items from the ModuleDict."""
self._modules.clear()
def pop(self, key: str) -> Module:
r"""Remove key from the ModuleDict and return its module.
Args:
key (str): key to pop from the ModuleDict
"""
v = self[key]
del self[key]
return v
def keys(self) -> Iterable[str]:
r"""Return an iterable of the ModuleDict keys."""
return self._modules.keys()
def items(self) -> Iterable[Tuple[str, Module]]:
r"""Return an iterable of the ModuleDict key/value pairs."""
return self._modules.items()
def values(self) -> Iterable[Module]:
r"""Return an iterable of the ModuleDict values."""
return self._modules.values()
def update(self, modules: Mapping[str, Module]) -> None:
r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys.
.. note::
If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
an iterable of key-value pairs, the order of new elements in it is preserved.
Args:
modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
"""
if not isinstance(modules, container_abcs.Iterable):
raise TypeError("ModuleDict.update should be called with an "
"iterable of key/value pairs, but got " +
type(modules).__name__)
if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
for key, module in modules.items():
self[key] = module
else:
# modules here can be a list with two items
for j, m in enumerate(modules):
if not isinstance(m, container_abcs.Iterable):
raise TypeError("ModuleDict update sequence element "
"#" + str(j) + " should be Iterable; is" +
type(m).__name__)
if not len(m) == 2:
raise ValueError("ModuleDict update sequence element "
"#" + str(j) + " has length " + str(len(m)) +
"; 2 is required")
# modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
# that's too cumbersome to type correctly with overloads, so we add an ignore here
self[m[0]] = m[1] # type: ignore[assignment]
# remove forward altogether to fall back on Module's _forward_unimplemented
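# Illustrative usage (hypothetical layers wrapped from jt.nn): ModuleDict
# behaves like an ordered dict of registered submodules.
#   layers = ModuleDict({"conv": Conv2d(3, 8, 3), "act": ReLU()})
#   layers["pool"] = Pool(2)        # registered via add_module
#   y = layers["conv"](x)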
class ParameterList(Module):
def __init__(self, values: Optional[Iterable[Any]] = None) -> None:
super().__init__()
self._size = 0
if values is not None:
self += values
def _get_abs_string_index(self, idx):
"""Get the absolute index for the list of modules."""
idx = operator.index(idx)
if not (-len(self) <= idx < len(self)):
raise IndexError(f'index {idx} is out of range')
if idx < 0:
idx += len(self)
return str(idx)
@overload
def __getitem__(self, idx: int) -> Any:
...
@overload
def __getitem__(self: T, idx: slice) -> T:
...
def __getitem__(self, idx):
if isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
out = self.__class__()
for i in range(start, stop, step):
out.append(self[i])
return out
else:
idx = self._get_abs_string_index(idx)
return getattr(self, str(idx))
def __setitem__(self, idx: int, param: Any) -> None:
# Note that all other functions that add an entry to the list part of
# the ParameterList end up here. So this is the only place where we need
# to wrap things into Parameter if needed.
# Objects added via setattr() are not in the list part and thus won't
# call into this function.
idx = self._get_abs_string_index(idx)
# Parameter is a plain function here rather than a class, so an isinstance
# check against it would raise TypeError; wrapping any Var is safe because
# Parameter() just clones and re-marks requires_grad.
if isinstance(param, jt.Var):
param = Parameter(param)
return setattr(self, str(idx), param)
def __len__(self) -> int:
return self._size
def __iter__(self) -> Iterator[Any]:
return iter(self[i] for i in range(len(self)))
def __iadd__(self, parameters: Iterable[Any]) -> Self:
return self.extend(parameters)
def __dir__(self):
keys = super().__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
def append(self, value: Any) -> 'ParameterList':
"""Append a given value at the end of the list.
Args:
value (Any): value to append
"""
new_idx = len(self)
self._size += 1
self[new_idx] = value
return self
def extend(self, values: Iterable[Any]) -> Self:
"""Append values from a Python iterable to the end of the list.
Args:
values (iterable): iterable of values to append
"""
# Tensor is an iterable but we never want to unpack it here
if not isinstance(values, container_abcs.Iterable) or isinstance(values, jt.Var):
raise TypeError("ParameterList.extend should be called with an "
"iterable, but got " + type(values).__name__)
for value in values:
self.append(value)
return self
def extra_repr(self) -> str:
child_lines = []
for k, p in enumerate(self):
if isinstance(p, jt.Var):
size_str = 'x'.join(str(size) for size in p.size())
parastr = '{} containing: [{} of size {} ({})]'.format(
"Parameter" if getattr(p, 'requires_grad', False) else "Tensor",
p.dtype, size_str, "cuda" if jt.flags.use_cuda else "cpu")
child_lines.append(' (' + str(k) + '): ' + parastr)
else:
child_lines.append(' (' + str(k) + '): Object of type: ' + type(p).__name__)
tmpstr = '\n'.join(child_lines)
return tmpstr
def __call__(self, *args, **kwargs):
raise RuntimeError('ParameterList should not be called.')
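# Illustrative usage (a sketch): Vars assigned through the list protocol are
# wrapped into Parameter (requires_grad=True) automatically.
#   params = ParameterList([jt.randn(4), jt.randn(4)])
#   params.append(jt.randn(2))
#   all(p.requires_grad for p in params)   # -> True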

View File

@ -1,16 +0,0 @@
import jittor as jt
for k,v in jt.nn.init.__dict__.items():
if callable(v):
globals()[k] = v
normal = gauss
normal_ = gauss_
xavier_normal = xavier_gauss
xavier_normal_ = xavier_gauss_
zeros_ = zero_
jt.Var.normal_ = normal_
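# Illustrative mapping (argument names assumed to follow jt.nn.init): the
# torch-style aliases dispatch to jittor's gaussian initializers, e.g.
#   w = jt.empty((3, 3))
#   normal_(w)          # jt.nn.init.gauss_ under the hood
#   xavier_normal_(w)   # jt.nn.init.xavier_gauss_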

View File

@ -1 +0,0 @@
from . import rnn

View File

@ -1,20 +0,0 @@
import jittor as jt
PackedSequence = None
def pad_sequence(sequences,batch_first=False,padding_value=0.0):
max_f = max([len(s) for s in sequences])
b = len(sequences)
if batch_first:
ret = sequences[0].new_full([b,max_f,]+list(sequences[0].shape[1:]),padding_value)
for i,s in enumerate(sequences):
ret[i,:len(s)] = s
else:
ret = sequences[0].new_full([max_f,b,]+list(sequences[0].shape[1:]),padding_value)
for i,s in enumerate(sequences):
ret[:len(s),i] = s
return ret
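# Illustrative usage: sequences of lengths 3/2/1 are right-padded to the
# longest one.
#   seqs = [jt.ones(3), jt.ones(2), jt.ones(1)]
#   pad_sequence(seqs, batch_first=True).shape   # -> (3, 3)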

File diff suppressed because it is too large

View File

@ -1,102 +0,0 @@
#include "pyjt/py_obj_holder.h"
#include "utils/str_utils.h"
#include "jtorch_core.h"
#include "graph.h"
#include "grad.h"
#include "ops/op_register.h"
namespace jittor {
void pyjt_def_all(PyObject* m);
EXTERN_LIB void setter_use_cuda(int value);
Device::Device(const string& name, int ordinal) : name(name) {
if (startswith(name, "cpu"))
setter_use_cuda(0);
else
setter_use_cuda(1);
}
unordered_map<int64, VarPtr> grad_backup;
EXTERN_LIB void (*_var_free_hook)(Var*);
EXTERN_LIB unordered_map<int64, VarPtr>* _grad_backup_ptr;
void jtorch_var_free_hook(Var* v) {
auto iter = grad_backup.find(v->id);
if (iter != grad_backup.end()) {
grad_backup.erase(iter);
}
}
void jtorch_init() {
_var_free_hook = &jtorch_var_free_hook;
_grad_backup_ptr = &grad_backup;
}
inline static VarPtr& get_grad(Var* v) {
return grad_backup[v->id];
}
static auto make_binary = get_op_info("binary")
.get_constructor<VarPtr, Var*, Var*, NanoString>();
inline static void add_grad(VarPtr& a, VarPtr&& b) {
if (!a) a = move(b);
else {
a = make_binary(a, b, ns_add);
}
}
void grad_set(VarHolder* x, Maybe<VarHolder> v) {
if (!v) {
grad_del(x);
return;
}
grad_backup[x->var->id] = v.ptr->var;
}
Maybe<VarHolder> grad_get(VarHolder* x) {
auto iter = grad_backup.find(x->var->id);
if (iter != grad_backup.end()) {
if (!iter->second.ptr) return nullptr;
return new VarHolder(iter->second.ptr);
}
return nullptr;
}
void grad_del(VarHolder* x) {
auto iter = grad_backup.find(x->var->id);
if (iter != grad_backup.end())
grad_backup.erase(iter);
}
void backward(VarHolder* x) {
vector<Node*> gnodes({x->var});
bfs_backward(gnodes, [&](Node* node) {
if (node->is_stop_grad())
return false;
return true;
});
vector<Var*> targets;
for (auto* node : gnodes) {
if (node->is_var() && node->flags.get(NodeFlags::_th_require_grad))
targets.push_back(node->var());
}
auto grads = grad(x->var, targets);
for (int i=0; i<targets.size(); i++) {
auto& gptr = get_grad(targets[i]);
add_grad(gptr, move(grads[i]));
}
}
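// Illustrative effect on the Python side (a sketch): for
//   x = jtorch.array(2.0); x.requires_grad = True
//   (x * x).backward(); (x * x).backward()
// x.grad accumulates to 8.0, because add_grad() sums successive gradients
// into grad_backup instead of overwriting them.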
}
static void init_module(PyModuleDef* mdef, PyObject* m) {
jittor::jtorch_init();
mdef->m_doc = "Inner c++ core of jtorch";
jittor::pyjt_def_all(m);
}
PYJT_MODULE_INIT(jtorch_core);

View File

@ -1,40 +0,0 @@
#pragma once
#include "common.h"
#include "var_holder.h"
#include "misc/fast_shared_ptr.h"
namespace jittor {
// @pyjt(device)
// @attrs(heaptype)
struct Device {
string name;
// @pyjt(__init__)
Device(const string& name, int ordinal=0);
// @pyjt(__get__type, __str__)
inline string get_type() {return name;}
// @pyjt(__get__index)
inline int index() {return 0;}
};
// @pyjt(backward)
void backward(VarHolder* x);
// @pyjt(grad_set)
void grad_set(VarHolder* x, Maybe<VarHolder> v);
// @pyjt(grad_get)
Maybe<VarHolder> grad_get(VarHolder* x);
// @pyjt(grad_del)
void grad_del(VarHolder* x);
// @pyjt(retain_grad_set)
inline void retain_grad_set(VarHolder* x, bool v) {
x->var->flags.set(NodeFlags::_th_require_grad, v);
}
// @pyjt(retain_grad_get)
inline bool retain_grad_get(VarHolder* x) {
return x->var->flags.get(NodeFlags::_th_require_grad);
}
}

View File

@ -1,25 +0,0 @@
import unittest
import numpy as np
import torch
import jittor as jt
class TestConflictFunc(unittest.TestCase):
def test_max(self):
a = torch.Tensor([1,4,2])
assert a.max() == 4
v, k = a.max(dim=0)
assert v==4 and k==1
def test_argsort(self):
a = torch.Tensor([1,4,2])
k = a.argsort()
assert jt.all_equal(k, [0,2,1])
with jt.flag_scope(th_mode=0):
k, v = a.argsort()
assert jt.all_equal(k, [0,2,1])
if __name__ == "__main__":
unittest.main()

View File

@ -1,58 +0,0 @@
import unittest
import numpy as np
import torch
class TestFunction(unittest.TestCase):
def test_example1(self):
import jtorch
from jtorch import Function
class MyFunc(Function):
@staticmethod
def forward(self, x, y):
self.x = x
self.y = y
return x*y, x/y
@staticmethod
def backward(self, grad0, grad1):
return grad0 * self.y, grad1 * self.x
a = jtorch.array(3.0)
a.requires_grad = True
b = jtorch.array(4.0)
b.requires_grad = True
func = MyFunc.apply
c,d = func(a, b)
(c+d*3).backward()
assert a.grad.data == 4
assert b.grad.data == 9
def test_example2(self):
import jtorch as jt
from jtorch import Function
class MyFunc(Function):
@staticmethod
def forward(self, x, y):
self.x = x
self.y = y
return x*y, x/y
@staticmethod
def backward(self, grad0, grad1):
assert grad1 is None
return grad0 * self.y, None
a = jt.array(3.0)
a.requires_grad = True
b = jt.array(4.0)
b.requires_grad = True
func = MyFunc.apply
c,d = func(a, b)
d.stop_grad()
da, db = jt.grad(c+d*3, [a, b])
assert da.data == 4
assert db.data == 0
if __name__ == "__main__":
unittest.main()

View File

@ -1,24 +0,0 @@
import unittest
import numpy as np
import torch
class TestMisc(unittest.TestCase):
def test_update_grad(self):
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = torch.nn.Parameter(torch.Tensor([1.0, 2.0]))
net = Net()
assert(net.a.requires_grad)
net.load_state_dict({"a": torch.Tensor([3.0, 4.0])})
assert(net.a.requires_grad)
def test_reshape(self):
a = torch.ones(3,3)
a.requires_grad = True
b = torch.reshape(a, [9])
assert b.requires_grad == True
if __name__ == "__main__":
unittest.main()

View File

@ -1,56 +0,0 @@
import unittest
import numpy as np
import os
import subprocess as sp
import sys
def check_two(cmd, parser=None, checker=None):
jtorch_out = sp.getoutput(cmd)
print("=========JTORCH OUT==========")
print(jtorch_out)
torch_out = sp.getoutput("PYTHONPATH= "+cmd)
print("=========TORCH OUT==========")
print(torch_out)
if parser:
torch_out = parser(torch_out)
jtorch_out = parser(jtorch_out)
if checker:
checker(torch_out, jtorch_out)
else:
assert torch_out == jtorch_out
return jtorch_out, torch_out
jtorch_path = os.path.join(os.path.dirname(__file__), "..")
# Adapted from https://pytorch.org/tutorials/beginner/pytorch_with_examples.html
class TestTutorial(unittest.TestCase):
def test_auto_grad1(self):
check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad1.py",
parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
def test_auto_grad2(self):
check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad2.py",
parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
def test_auto_grad3(self):
check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad3.py",
parser=lambda s: np.array(s.split())[[-9,-7,-4,-2]].astype(float),
checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
def test_auto_grad4(self):
check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad4.py",
parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
def test_auto_grad5(self):
check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad5_optim.py",
parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2))
def test_auto_grad6(self):
check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad6_module.py",
parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
def test_auto_grad7(self):
check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad7_dynet.py",
parser=lambda s: np.array(s.split())[[-13,-10,-7,-3]].astype(float),
checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2))
if __name__ == "__main__":
unittest.main()

View File

@ -1,44 +0,0 @@
import torch
import math
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU
# Create random input and output data
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)
# Randomly initialize weights
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)
learning_rate = 1e-6
for t in range(20000):
# Forward pass: compute predicted y
y_pred = a + b * x + c * x ** 2 + d * x ** 3
# Compute and print loss
loss = (y_pred - y).pow(2).sum().item()
if t % 1000 == 999:
print(t, loss)
# Backprop to compute gradients of a, b, c, d with respect to loss
grad_y_pred = 2.0 * (y_pred - y)
grad_a = grad_y_pred.sum()
grad_b = (grad_y_pred * x).sum()
grad_c = (grad_y_pred * x ** 2).sum()
grad_d = (grad_y_pred * x ** 3).sum()
# Update weights using gradient descent
a -= learning_rate * grad_a
b -= learning_rate * grad_b
c -= learning_rate * grad_c
d -= learning_rate * grad_d
print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

View File

@ -1,60 +0,0 @@
# -*- coding: utf-8 -*-
import torch
import math
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU
# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)
# Create random Tensors for weights. For a third order polynomial, we need
# 4 weights: y = a + b x + c x^2 + d x^3
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.randn((), device=device, dtype=dtype, requires_grad=True)
b = torch.randn((), device=device, dtype=dtype, requires_grad=True)
c = torch.randn((), device=device, dtype=dtype, requires_grad=True)
d = torch.randn((), device=device, dtype=dtype, requires_grad=True)
learning_rate = 1e-6
for t in range(20000):
# Forward pass: compute predicted y using operations on Tensors.
y_pred = a + b * x + c * x ** 2 + d * x ** 3
# Compute and print loss using operations on Tensors.
# Now loss is a Tensor of shape (1,)
# loss.item() gets the scalar value held in the loss.
loss = (y_pred - y).pow(2).sum()
if t % 1000 == 999:
print(t, loss.item())
# Use autograd to compute the backward pass. This call will compute the
# gradient of loss with respect to all Tensors with requires_grad=True.
# After this call a.grad, b.grad, c.grad and d.grad will be Tensors holding
# the gradient of the loss with respect to a, b, c, d respectively.
# torch.backward(loss)
loss.backward()
# Manually update weights using gradient descent. Wrap in torch.no_grad()
# because weights have requires_grad=True, but we don't need to track this
# in autograd.
with torch.no_grad():
a -= learning_rate * a.grad
b -= learning_rate * b.grad
c -= learning_rate * c.grad
d -= learning_rate * d.grad
# Manually zero the gradients after updating weights
a.grad = None
b.grad = None
c.grad = None
d.grad = None
print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

View File

@ -1,85 +0,0 @@
# -*- coding: utf-8 -*-
import torch
import math
class LegendrePolynomial3(torch.autograd.Function):
"""
We can implement our own custom autograd Functions by subclassing
torch.autograd.Function and implementing the forward and backward passes
which operate on Tensors.
"""
@staticmethod
def forward(ctx, input):
"""
In the forward pass we receive a Tensor containing the input and return
a Tensor containing the output. ctx is a context object that can be used
to stash information for backward computation. You can cache arbitrary
objects for use in the backward pass using the ctx.save_for_backward method.
"""
ctx.save_for_backward(input)
return 0.5 * (5 * input ** 3 - 3 * input)
@staticmethod
def backward(ctx, grad_output):
"""
In the backward pass we receive a Tensor containing the gradient of the loss
with respect to the output, and we need to compute the gradient of the loss
with respect to the input.
"""
input, = ctx.saved_tensors
return grad_output * 1.5 * (5 * input ** 2 - 1)
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU
# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)
# Create random Tensors for weights. For this example, we need
# 4 weights: y = a + b * P3(c + d * x), these weights need to be initialized
# not too far from the correct result to ensure convergence.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)
learning_rate = 5e-6
for t in range(2000):
# To apply our Function, we use Function.apply method. We alias this as 'P3'.
P3 = LegendrePolynomial3.apply
# Forward pass: compute predicted y using operations; we compute
# P3 using our custom autograd operation.
y_pred = a + b * P3(c + d * x)
# Compute and print loss
loss = (y_pred - y).pow(2).sum()
if t % 100 == 99:
print(t, loss.item())
# Use autograd to compute the backward pass.
loss.backward()
# Update weights using gradient descent
with torch.no_grad():
a -= learning_rate * a.grad
b -= learning_rate * b.grad
c -= learning_rate * c.grad
d -= learning_rate * d.grad
# Manually zero the gradients after updating weights
a.grad = None
b.grad = None
c.grad = None
d.grad = None
print(f'Result: y = {a.item()} + {b.item()} * P3( {c.item()} + {d.item()} x)')

View File

@ -1,71 +0,0 @@
# -*- coding: utf-8 -*-
import torch
import math
# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)
# For this example, the output y is a linear function of (x, x^2, x^3), so
# we can consider it as a linear layer neural network. Let's prepare the
# tensor (x, x^2, x^3).
p = torch.tensor([1, 2, 3])
xx = x.unsqueeze(-1).pow(p)
# In the above code, x.unsqueeze(-1) has shape (2000, 1), and p has shape
# (3,), for this case, broadcasting semantics will apply to obtain a tensor
# of shape (2000, 3)
# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. The Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
# The Flatten layer flattens the output of the linear layer to a 1D tensor,
# to match the shape of `y`.
model = torch.nn.Sequential(
torch.nn.Linear(3, 1),
torch.nn.Flatten(0, 1)
)
# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 1e-6
for t in range(8000):
# Forward pass: compute predicted y by passing x to the model. Module objects
# override the __call__ operator so you can call them like functions. When
# doing so you pass a Tensor of input data to the Module and it produces
# a Tensor of output data.
y_pred = model(xx)
# Compute and print loss. We pass Tensors containing the predicted and true
# values of y, and the loss function returns a Tensor containing the
# loss.
loss = loss_fn(y_pred, y)
if t % 1000 == 999:
print(t, loss.item())
# Zero the gradients before running the backward pass.
model.zero_grad()
# Backward pass: compute gradient of the loss with respect to all the learnable
# parameters of the model. Internally, the parameters of each Module are stored
# in Tensors with requires_grad=True, so this call will compute gradients for
# all learnable parameters in the model.
loss.backward()
# Update the weights using gradient descent. Each parameter is a Tensor, so
# we can access its gradients like we did before.
with torch.no_grad():
for param in model.parameters():
param -= learning_rate * param.grad
# You can access the first layer of `model` like accessing the first item of a list
linear_layer = model[0]
# For linear layer, its parameters are stored as `weight` and `bias`.
print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3')

View File

@ -1,53 +0,0 @@
# -*- coding: utf-8 -*-
import torch
import math
# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)
# Prepare the input tensor (x, x^2, x^3).
p = torch.tensor([1, 2, 3])
xx = x.unsqueeze(-1).pow(p)
# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
torch.nn.Linear(3, 1),
torch.nn.Flatten(0, 1)
)
loss_fn = torch.nn.MSELoss(reduction='sum')
# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use RMSprop; the optim package contains many other
# optimization algorithms. The first argument to the RMSprop constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-3
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
for t in range(8000):
# Forward pass: compute predicted y by passing x to the model.
y_pred = model(xx)
# Compute and print loss.
loss = loss_fn(y_pred, y)
if t % 1000 == 999:
print(t, loss.item())
# Before the backward pass, use the optimizer object to zero all of the
# gradients for the variables it will update (which are the learnable
# weights of the model). This is because by default, gradients are
# accumulated in buffers (i.e., not overwritten) whenever .backward()
# is called. Check out the docs of torch.autograd.backward for more details.
optimizer.zero_grad()
# Backward pass: compute gradient of the loss with respect to model
# parameters
loss.backward()
# Calling the step function on an Optimizer makes an update to its
# parameters
optimizer.step()
linear_layer = model[0]
print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3')

View File

@ -1,59 +0,0 @@
# -*- coding: utf-8 -*-
import torch
import math
class Polynomial3(torch.nn.Module):
def __init__(self):
"""
In the constructor we instantiate four parameters and assign them as
member parameters.
"""
super().__init__()
self.a = torch.nn.Parameter(torch.randn(()))
self.b = torch.nn.Parameter(torch.randn(()))
self.c = torch.nn.Parameter(torch.randn(()))
self.d = torch.nn.Parameter(torch.randn(()))
def forward(self, x):
"""
In the forward function we accept a Tensor of input data and we must return
a Tensor of output data. We can use Modules defined in the constructor as
well as arbitrary operators on Tensors.
"""
return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3
def string(self):
"""
Just like any class in Python, you can also define custom method on PyTorch modules
"""
return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3'
# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)
# Construct our model by instantiating the class defined above
model = Polynomial3()
# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters (defined
# with torch.nn.Parameter) which are members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)
for t in range(8000):
# Forward pass: Compute predicted y by passing x to the model
y_pred = model(x)
# Compute and print loss
loss = criterion(y_pred, y)
if t % 1000 == 999:
print(t, loss.item())
# Zero gradients, perform a backward pass, and update the weights.
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f'Result: {model.string()}')

View File

@ -1,69 +0,0 @@
# -*- coding: utf-8 -*-
import random
import torch
import math
class DynamicNet(torch.nn.Module):
def __init__(self):
"""
In the constructor we instantiate five parameters and assign them as members.
"""
super().__init__()
self.a = torch.nn.Parameter(torch.randn(()))
self.b = torch.nn.Parameter(torch.randn(()))
self.c = torch.nn.Parameter(torch.randn(()))
self.d = torch.nn.Parameter(torch.randn(()))
self.e = torch.nn.Parameter(torch.randn(()))
def forward(self, x):
"""
For the forward pass of the model, we randomly choose the polynomial order
to be 3, 4, or 5, reusing the e parameter to compute the contribution of
the fourth- and fifth-order terms when they are present.
Since each forward pass builds a dynamic computation graph, we can use normal
Python control-flow operators like loops or conditional statements when
defining the forward pass of the model.
Here we also see that it is perfectly safe to reuse the same parameter many
times when defining a computational graph.
"""
y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3
for exp in range(4, random.randint(4, 6)):
y = y + self.e * x ** exp
return y
def string(self):
"""
Just like any class in Python, you can also define custom method on PyTorch modules
"""
return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?'
# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)
# Construct our model by instantiating the class defined above
model = DynamicNet()
# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9)
for t in range(60000):
# Forward pass: Compute predicted y by passing x to the model
y_pred = model(x)
# Compute and print loss
loss = criterion(y_pred, y)
if t % 2000 == 1999:
print(t, loss.item())
# Zero gradients, perform a backward pass, and update the weights.
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f'Result: {model.string()}')

View File

@ -1,106 +0,0 @@
import torch
from torch import nn
# from jtorch.utils import DataLoader
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
)
# Download test data from open datasets.
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor(),
)
batch_size = 64
# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
print(len(train_dataloader))
for X, y in test_dataloader:
print(f"Shape of X [N, C, H, W]: {X.shape}")
print(f"Shape of y: {y.shape} {y.dtype}")
break
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
# Define model
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10)
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = NeuralNetwork().to(device)
print(model)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
def train(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
model.train()
for batch, (X, y) in enumerate(dataloader):
X, y = X.to(device), y.to(device)
# Compute prediction error
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
epochs = 5
test(test_dataloader, model, loss_fn)
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train(train_dataloader, model, loss_fn, optimizer)
test(test_dataloader, model, loss_fn)
print("Done!")

View File

@ -1,5 +0,0 @@
cpp_extension = None
_flatten_dense_tensors = None
_unflatten_dense_tensors = None
tensorboard = None

View File

@ -1,3 +0,0 @@
#TODO: Implement this
_register_pytree_node = None
_dict_flatten = None

View File

@ -1,8 +0,0 @@
detach_variable = None
def checkpoint(
*args,
**kwargs
):
pass

View File

@ -1,137 +0,0 @@
import jittor as jt
import jittor.dataset
from jittor.dataset import Dataset as JDataset
from collections import namedtuple
from typing import Any, Callable, Iterable, Optional, Sequence, Union
class Dataset:
def __getitem__(self, index):
raise NotImplementedError
class IterableDataset:
def __iter__(self):
raise NotImplementedError
class DataLoader(JDataset):
def __init__(self, dataset,
batch_size: Optional[int] = 1,
shuffle: Optional[bool] = False,
sampler = None,
batch_sampler = None,
num_workers: int = 0,
collate_fn = None,
pin_memory: bool = False,
drop_last: bool = False,
timeout: float = 0,
worker_init_fn = None,
multiprocessing_context=None,
generator=None,
*, prefetch_factor: int = 2,
persistent_workers: bool = False,
pin_memory_device: str = "") -> None:
super().__init__(batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
drop_last=drop_last)
unsupported_kwargs = {
"batch_sampler": batch_sampler,
"pin_memory": pin_memory,
"timeout": timeout,
"worker_init_fn": worker_init_fn,
"multiprocessing_context": multiprocessing_context,
"generator": generator,
"persistent_workers": persistent_workers,
"pin_memory_device": pin_memory_device
}
for kwarg, value in unsupported_kwargs.items():
if value:
jt.LOG.w(f"Not implemented Dataloader kwarg: {kwarg}")
self.dataset = dataset
self.collate_fn = collate_fn
self.sampler = sampler
if not isinstance(dataset, IterableDataset):
self.total_len = len(dataset)
else:
# TODO: support multiple worker for iterable dataset
assert(num_workers == 0)
def collate_batch(self, batch):
if self.collate_fn is not None:
return self.collate_fn(batch)
else:
return super().collate_batch(batch)
def __getitem__(self, i):
return self.dataset[i]
def __iter__(self):
if isinstance(self.dataset, IterableDataset):
return self.inner_iter()
else:
return super().__iter__()
def inner_iter(self):
current_batch = []
if jt.world_size > 1:
assert self.batch_size % jt.world_size == 0, \
f"IterableDataset does not support a batch size ({self.batch_size}) that is not evenly divisible by the number of processes f{jt.world_size}"
real_batch_size = int(self.batch_size / jt.world_size)
else:
real_batch_size = self.batch_size
for element in self.dataset:
current_batch.append(element)
if len(current_batch) == real_batch_size:
current_batch = self.collate_batch(current_batch)
current_batch = self.to_jittor(current_batch)
yield current_batch
current_batch = []
if not self.drop_last and len(current_batch) > 0:
current_batch = self.collate_batch(current_batch)
yield self.to_jittor(current_batch)
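# Illustrative usage (hypothetical iterable dataset): the wrapper accepts both
# map-style datasets and IterableDataset instances.
#   class Squares(IterableDataset):
#       def __iter__(self):
#           return iter(i * i for i in range(8))
#   loader = DataLoader(Squares(), batch_size=4)
#   for batch in loader:   # two collated batches of 4
#       ...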
# def get_worker_info():
# # always return the fake worker info
# return namedtuple('WorkerInfo', 'id num_workers')(0, 1)
# class RandomSampler(jt.dataset.RandomSampler):
# def __init__(self, dataset, generator=None, **kwargs):
# super().__init__(dataset, **kwargs)
# def __iter__(self):
# if getattr(self.dataset, "support_random_access", True):
# return super().__iter__()
# else:
# self.dataset.shuffle()
# return iter(range(self.dataset.__real_len__() if hasattr(self.dataset,"__real_len__") else self.dataset.__len__()))
# class DistributedSampler(jt.dataset.Sampler):
# def __init__(self, sampler: RandomSampler):
# assert(isinstance(sampler, RandomSampler))
# self.sampler = sampler
# def set_epoch(self, epoch: int):
# ### do nothing, let jittor's inner dataset handle
# pass
# def __iter__(self):
# return self.sampler.__iter__()
# def __len__(self):
# return self.sampler.__len__()
# BatchSampler = jt.dataset.BatchSampler
# Sampler = jt.dataset.Sampler
# SequentialSampler = jt.dataset.SequentialSampler
# SubsetRandomSampler = jt.dataset.SubsetRandomSampler
# TensorDataset = Dataset

View File

@ -1,9 +0,0 @@
from typing import Callable, Union
Dtype = Union[Callable, str]
def get_string_dtype(dtype):
if callable(dtype):
dtype = dtype.__name__
if not isinstance(dtype, str):
raise ValueError(f"dtype is expected to be str, python type function, or jittor type function, but got {dtype}.")
return dtype
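# Illustrative behavior: callables and strings both normalize to a string name.
#   get_string_dtype(float)      # -> "float"
#   get_string_dtype("float32")  # -> "float32"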

View File

@ -1,34 +0,0 @@
import os
import glob
import shutil
import sys
home_path = os.path.join(os.path.dirname(__file__), "..", "..", "..")
home_path = os.path.abspath(home_path)
def callback(func, path, exc_info):
print(f"remove \"{path}\" failed.")
def rmtree(path):
if os.path.isdir(path):
print(f"remove \"{path}\" recursive.")
shutil.rmtree(path, onerror=callback)
def remove_tmpfile():
dist_file = home_path+"/dist"
egg_file = glob.glob(home_path+"/**/*egg-info")
rmtree(dist_file)
for e in egg_file:
rmtree(e)
def run_cmd(cmd):
print("[CMD]", cmd)
assert os.system(cmd)==0
os.chdir(home_path)
remove_tmpfile()
run_cmd(f"{sys.executable} ./setup.py sdist")
run_cmd(f"{sys.executable} -m twine upload dist/*")
remove_tmpfile()

View File

@ -1,46 +0,0 @@
import importlib.machinery
import os
def _download_file_from_remote_location(fpath: str, url: str) -> None:
pass
def _is_remote_location_available() -> bool:
return False
def _get_extension_path(lib_name):
lib_dir = os.path.dirname(__file__)
if os.name == "nt":
# Register the main torchvision library location on the default DLL path
import ctypes
import sys
kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
prev_error_mode = kernel32.SetErrorMode(0x0001)
if with_load_library_flags:
kernel32.AddDllDirectory.restype = ctypes.c_void_p
if sys.version_info >= (3, 8):
os.add_dll_directory(lib_dir)
elif with_load_library_flags:
res = kernel32.AddDllDirectory(lib_dir)
if res is None:
err = ctypes.WinError(ctypes.get_last_error())
err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
raise err
kernel32.SetErrorMode(prev_error_mode)
loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)
extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
ext_specs = extfinder.find_spec(lib_name)
if ext_specs is None:
raise ImportError
return ext_specs.origin

View File

@ -1,9 +0,0 @@
from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
__all__ = (
"EMNIST",
"FashionMNIST",
"QMNIST",
"MNIST",
"KMNIST",
)

View File

@ -1,558 +0,0 @@
import codecs
import os
import os.path
import shutil
import string
import sys
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.error import URLError
import numpy as np
import torch
from PIL import Image
from .utils import check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
from .vision import VisionDataset
class MNIST(VisionDataset):
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
Args:
root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
and ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
otherwise from ``t10k-images-idx3-ubyte``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
mirrors = [
"http://yann.lecun.com/exdb/mnist/",
"https://ossci-datasets.s3.amazonaws.com/mnist/",
]
resources = [
("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
]
training_file = "training.pt"
test_file = "test.pt"
classes = [
"0 - zero",
"1 - one",
"2 - two",
"3 - three",
"4 - four",
"5 - five",
"6 - six",
"7 - seven",
"8 - eight",
"9 - nine",
]
@property
def train_labels(self):
warnings.warn("train_labels has been renamed targets")
return self.targets
@property
def test_labels(self):
warnings.warn("test_labels has been renamed targets")
return self.targets
@property
def train_data(self):
warnings.warn("train_data has been renamed data")
return self.data
@property
def test_data(self):
warnings.warn("test_data has been renamed data")
return self.data
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self.train = train # training set or test set
if self._check_legacy_exist():
self.data, self.targets = self._load_legacy_data()
return
if download:
self.download()
if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")
self.data, self.targets = self._load_data()
def _check_legacy_exist(self):
processed_folder_exists = os.path.exists(self.processed_folder)
if not processed_folder_exists:
return False
return all(
check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
)
def _load_legacy_data(self):
# This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
# directly.
data_file = self.training_file if self.train else self.test_file
return torch.load(os.path.join(self.processed_folder, data_file))
def _load_data(self):
image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
data = read_image_file(os.path.join(self.raw_folder, image_file))
label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
targets = read_label_file(os.path.join(self.raw_folder, label_file))
return data, targets
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], int(self.targets[index])
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode="L")
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self) -> int:
return len(self.data)
@property
def raw_folder(self) -> str:
return os.path.join(self.root, self.__class__.__name__, "raw")
@property
def processed_folder(self) -> str:
return os.path.join(self.root, self.__class__.__name__, "processed")
@property
def class_to_idx(self) -> Dict[str, int]:
return {_class: i for i, _class in enumerate(self.classes)}
def _check_exists(self) -> bool:
return all(
check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
for url, _ in self.resources
)
def download(self) -> None:
"""Download the MNIST data if it doesn't exist already."""
if self._check_exists():
return
os.makedirs(self.raw_folder, exist_ok=True)
# download files
for filename, md5 in self.resources:
for mirror in self.mirrors:
url = f"{mirror}{filename}"
try:
print(f"Downloading {url}")
download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
except URLError as error:
print(f"Failed to download (trying next):\n{error}")
continue
finally:
print()
break
else:
raise RuntimeError(f"Error downloading {filename}")
def extra_repr(self) -> str:
split = "Train" if self.train is True else "Test"
return f"Split: {split}"
class FashionMNIST(MNIST):
"""`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.
Args:
root (string): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte``
and ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist.
train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
otherwise from ``t10k-images-idx3-ubyte``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]
resources = [
("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"),
("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"),
("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"),
]
classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
class KMNIST(MNIST):
"""`Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ Dataset.
Args:
root (string): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte``
and ``KMNIST/raw/t10k-images-idx3-ubyte`` exist.
train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
otherwise from ``t10k-images-idx3-ubyte``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]
resources = [
("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"),
]
classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"]
class EMNIST(MNIST):
"""`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.
Args:
root (string): Root directory of dataset where ``EMNIST/raw/train-images-idx3-ubyte``
and ``EMNIST/raw/t10k-images-idx3-ubyte`` exist.
split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
which one to use.
train (bool, optional): If True, creates dataset from ``training.pt``,
otherwise from ``test.pt``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
url = "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip"
md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist")
# Merged classes assume the same structure for both uppercase and lowercase versions
_merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"}
_all_classes = set(string.digits + string.ascii_letters)
classes_split_dict = {
"byclass": sorted(list(_all_classes)),
"bymerge": sorted(list(_all_classes - _merged_classes)),
"balanced": sorted(list(_all_classes - _merged_classes)),
"letters": ["N/A"] + list(string.ascii_lowercase),
"digits": list(string.digits),
"mnist": list(string.digits),
}
def __init__(self, root: str, split: str, **kwargs: Any) -> None:
self.split = verify_str_arg(split, "split", self.splits)
self.training_file = self._training_file(split)
self.test_file = self._test_file(split)
super().__init__(root, **kwargs)
self.classes = self.classes_split_dict[self.split]
@staticmethod
def _training_file(split) -> str:
return f"training_{split}.pt"
@staticmethod
def _test_file(split) -> str:
return f"test_{split}.pt"
@property
def _file_prefix(self) -> str:
return f"emnist-{self.split}-{'train' if self.train else 'test'}"
@property
def images_file(self) -> str:
return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")
@property
def labels_file(self) -> str:
return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")
def _load_data(self):
return read_image_file(self.images_file), read_label_file(self.labels_file)
def _check_exists(self) -> bool:
return all(check_integrity(file) for file in (self.images_file, self.labels_file))
def download(self) -> None:
"""Download the EMNIST data if it doesn't exist already."""
if self._check_exists():
return
os.makedirs(self.raw_folder, exist_ok=True)
download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
gzip_folder = os.path.join(self.raw_folder, "gzip")
for gzip_file in os.listdir(gzip_folder):
if gzip_file.endswith(".gz"):
extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
shutil.rmtree(gzip_folder)
class QMNIST(MNIST):
"""`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.
Args:
root (string): Root directory of dataset whose ``raw``
subdir contains binary files of the datasets.
what (string,optional): Can be 'train', 'test', 'test10k',
'test50k', or 'nist' for respectively the mnist compatible
training set, the 60k qmnist testing set, the 10k qmnist
examples that match the mnist testing set, the 50k
remaining qmnist testing examples, or all the nist
digits. The default is to select 'train' or 'test'
according to the compatibility argument 'train'.
compat (bool,optional): A boolean that says whether the target
for each example is class number (for compatibility with
the MNIST dataloader) or a torch vector containing the
full qmnist information. Default=True.
download (bool, optional): If True, downloads the dataset from
the internet and puts it in root directory. If dataset is
already downloaded, it is not downloaded again.
transform (callable, optional): A function/transform that
takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform
that takes in the target and transforms it.
train (bool,optional,compatibility): When argument 'what' is
not specified, this boolean decides whether to load the
training set or the testing set. Default: True.
"""
subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"}
resources: Dict[str, List[Tuple[str, str]]] = { # type: ignore[assignment]
"train": [
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz",
"ed72d4157d28c017586c42bc6afe6370",
),
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz",
"0058f8dd561b90ffdd0f734c6a30e5e4",
),
],
"test": [
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz",
"1394631089c404de565df7b7aeaf9412",
),
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz",
"5b5b05890a5e13444e108efe57b788aa",
),
],
"nist": [
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz",
"7f124b3b8ab81486c9d8c2749c17f834",
),
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz",
"5ed0e788978e45d4a8bd4b7caec3d79d",
),
],
}
classes = [
"0 - zero",
"1 - one",
"2 - two",
"3 - three",
"4 - four",
"5 - five",
"6 - six",
"7 - seven",
"8 - eight",
"9 - nine",
]
def __init__(
self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any
) -> None:
if what is None:
what = "train" if train else "test"
self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
self.compat = compat
self.data_file = what + ".pt"
self.training_file = self.data_file
self.test_file = self.data_file
super().__init__(root, train, **kwargs)
@property
def images_file(self) -> str:
(url, _), _ = self.resources[self.subsets[self.what]]
return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
@property
def labels_file(self) -> str:
_, (url, _) = self.resources[self.subsets[self.what]]
return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
def _check_exists(self) -> bool:
return all(check_integrity(file) for file in (self.images_file, self.labels_file))
def _load_data(self):
data = read_sn3_pascalvincent_tensor(self.images_file)
if data.dtype != torch.uint8:
raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}")
if data.ndimension() != 3:
raise ValueError("data should have 3 dimensions instead of {data.ndimension()}")
targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
if targets.ndimension() != 2:
raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}")
if self.what == "test10k":
data = data[0:10000, :, :].clone()
targets = targets[0:10000, :].clone()
elif self.what == "test50k":
data = data[10000:, :, :].clone()
targets = targets[10000:, :].clone()
return data, targets
def download(self) -> None:
"""Download the QMNIST data if it doesn't exist already.
Note that we only download what has been asked for (argument 'what').
"""
if self._check_exists():
return
os.makedirs(self.raw_folder, exist_ok=True)
split = self.resources[self.subsets[self.what]]
for url, md5 in split:
download_and_extract_archive(url, self.raw_folder, md5=md5)
def __getitem__(self, index: int) -> Tuple[Any, Any]:
# redefined to handle the compat flag
img, target = self.data[index], self.targets[index]
img = Image.fromarray(img.numpy(), mode="L")
if self.transform is not None:
img = self.transform(img)
if self.compat:
target = int(target[0])
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def extra_repr(self) -> str:
return f"Split: {self.what}"
def get_int(b: bytes) -> int:
return int(codecs.encode(b, "hex"), 16)
SN3_PASCALVINCENT_BITSMAP = {
8: torch.uint8,
9: torch.int8,
11: torch.int16,
12: torch.int32,
13: torch.float32,
14: torch.float64,
}
TORCH_TYPE_BITS = {
torch.uint8: 8,
torch.int8: 8,
torch.int16: 16,
torch.int32: 32,
torch.float32: 32,
torch.float64: 64,
}
def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
Argument may be a filename, compressed filename, or file object.
"""
# read
with open(path, "rb") as f:
data = f.read()
# parse
magic = get_int(data[0:4])
nd = magic % 256
ty = magic // 256
assert 1 <= nd <= 3
assert 8 <= ty <= 14
torch_type = SN3_PASCALVINCENT_BITSMAP[ty]
s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
num_bytes_per_value = TORCH_TYPE_BITS[torch_type] // 8
# The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
# we need to reverse the bytes before we can read them with torch.frombuffer().
needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))
if needs_byte_reversal:
parsed = parsed.flip(0)
assert parsed.shape[0] == np.prod(s) or not strict
return parsed.view(*s)
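# Illustrative parse of an MNIST image header: magic 2051 yields
# nd = 2051 % 256 = 3 dimensions and ty = 2051 // 256 = 8 -> torch.uint8,
# followed by three 4-byte big-endian sizes (e.g. 60000, 28, 28).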
def read_label_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False)
if x.dtype != torch.uint8:
raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
if x.ndimension() != 1:
raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}")
return x.long()
def read_image_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False)
if x.dtype != torch.uint8:
raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
if x.ndimension() != 3:
raise ValueError(f"x should have 3 dimension instead of {x.ndimension()}")
return x

View File

@@ -1,522 +0,0 @@
import bz2
import contextlib
import gzip
import hashlib
import itertools
import lzma
import os
import os.path
import pathlib
import re
import sys
import tarfile
import urllib
import urllib.error
import urllib.request
import warnings
import zipfile
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar
from urllib.parse import urlparse
import numpy as np
import requests
import torch
from tqdm import tqdm
from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available
USER_AGENT = "pytorch/vision"
def _save_response_content(
content: Iterator[bytes],
destination: str,
length: Optional[int] = None,
) -> None:
with open(destination, "wb") as fh, tqdm(total=length) as pbar:
for chunk in content:
# filter out keep-alive new chunks
if not chunk:
continue
fh.write(chunk)
pbar.update(len(chunk))
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
_save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)
def gen_bar_updater() -> Callable[[int, int, int], None]:
warnings.warn("The function `gen_bar_update` is deprecated since 0.13 and will be removed in 0.15.")
pbar = tqdm(total=None)
def bar_update(count, block_size, total_size):
if pbar.total is None and total_size:
pbar.total = total_size
progress_bytes = count * block_size
pbar.update(progress_bytes - pbar.n)
return bar_update
def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
# Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
# not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
# it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere.
if sys.version_info >= (3, 9):
md5 = hashlib.md5(usedforsecurity=False)
else:
md5 = hashlib.md5()
with open(fpath, "rb") as f:
for chunk in iter(lambda: f.read(chunk_size), b""):
md5.update(chunk)
return md5.hexdigest()
def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
return md5 == calculate_md5(fpath, **kwargs)
def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
if not os.path.isfile(fpath):
return False
if md5 is None:
return True
return check_md5(fpath, md5)
def _get_redirect_url(url: str, max_hops: int = 3) -> str:
initial_url = url
headers = {"Method": "HEAD", "User-Agent": USER_AGENT}
for _ in range(max_hops + 1):
with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
if response.url == url or response.url is None:
return url
url = response.url
else:
raise RecursionError(
f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}."
)
def _get_google_drive_file_id(url: str) -> Optional[str]:
parts = urlparse(url)
if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
return None
match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
if match is None:
return None
return match.group("id")
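# For illustration: _get_google_drive_file_id("https://drive.google.com/file/d/abc123/view")
# returns "abc123", while any non-Drive URL returns None.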
def download_url(
url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
) -> None:
"""Download a file from a url and place it in root.
Args:
url (str): URL to download file from
root (str): Directory to place downloaded file in
filename (str, optional): Name to save the file under. If None, use the basename of the URL
md5 (str, optional): MD5 checksum of the download. If None, do not check
max_redirect_hops (int, optional): Maximum number of redirect hops allowed
"""
root = os.path.expanduser(root)
if not filename:
filename = os.path.basename(url)
fpath = os.path.join(root, filename)
os.makedirs(root, exist_ok=True)
# check if file is already present locally
if check_integrity(fpath, md5):
print("Using downloaded and verified file: " + fpath)
return
if _is_remote_location_available():
_download_file_from_remote_location(fpath, url)
else:
# expand redirect chain if needed
url = _get_redirect_url(url, max_hops=max_redirect_hops)
# check if file is located on Google Drive
file_id = _get_google_drive_file_id(url)
if file_id is not None:
return download_file_from_google_drive(file_id, root, filename, md5)
# download the file
try:
print("Downloading " + url + " to " + fpath)
_urlretrieve(url, fpath)
except (urllib.error.URLError, OSError) as e: # type: ignore[attr-defined]
if url[:5] == "https":
url = url.replace("https:", "http:")
print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
_urlretrieve(url, fpath)
else:
raise e
# check integrity of downloaded file
if not check_integrity(fpath, md5):
raise RuntimeError("File not found or corrupted.")
def list_dir(root: str, prefix: bool = False) -> List[str]:
"""List all directories at a given root
Args:
root (str): Path to directory whose folders need to be listed
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the directories found
"""
root = os.path.expanduser(root)
directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
if prefix is True:
directories = [os.path.join(root, d) for d in directories]
return directories
def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
"""List all files ending with a suffix at a given root
Args:
root (str): Path to directory whose files need to be listed
suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
It uses the Python "str.endswith" method and is passed directly
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the files found
"""
root = os.path.expanduser(root)
files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
if prefix is True:
files = [os.path.join(root, d) for d in files]
return files
def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]:
content = response.iter_content(chunk_size)
first_chunk = None
# filter out keep-alive new chunks
while not first_chunk:
first_chunk = next(content)
content = itertools.chain([first_chunk], content)
try:
match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
api_response = match["api_response"] if match is not None else None
except UnicodeDecodeError:
api_response = None
return api_response, content
def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
"""Download a Google Drive file from and place it in root.
Args:
file_id (str): id of file to be downloaded
root (str): Directory to place downloaded file in
filename (str, optional): Name to save the file under. If None, use the id of the file.
md5 (str, optional): MD5 checksum of the download. If None, do not check
"""
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
root = os.path.expanduser(root)
if not filename:
filename = file_id
fpath = os.path.join(root, filename)
os.makedirs(root, exist_ok=True)
if check_integrity(fpath, md5):
print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
return
url = "https://drive.google.com/uc"
params = dict(id=file_id, export="download")
with requests.Session() as session:
response = session.get(url, params=params, stream=True)
for key, value in response.cookies.items():
if key.startswith("download_warning"):
token = value
break
else:
api_response, content = _extract_gdrive_api_response(response)
token = "t" if api_response == "Virus scan warning" else None
if token is not None:
response = session.get(url, params=dict(params, confirm=token), stream=True)
api_response, content = _extract_gdrive_api_response(response)
if api_response == "Quota exceeded":
raise RuntimeError(
f"The daily quota of the file {filename} is exceeded and it "
f"can't be downloaded. This is a limitation of Google Drive "
f"and can only be overcome by trying again later."
)
_save_response_content(content, fpath)
# In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
if os.stat(fpath).st_size < 10 * 1024:
with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh:
text = fh.read()
# Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
warnings.warn(
f"We detected some HTML elements in the downloaded file. "
f"This most likely means that the download triggered an unhandled API response by GDrive. "
f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
f"the response:\n\n{text}"
)
if md5 and not check_md5(fpath, md5):
raise RuntimeError(
f"The MD5 checksum of the download file {fpath} does not match the one on record."
f"Please delete the file and try again. "
f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
)
def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
tar.extractall(to_path)
_ZIP_COMPRESSION_MAP: Dict[str, int] = {
".bz2": zipfile.ZIP_BZIP2,
".xz": zipfile.ZIP_LZMA,
}
def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
with zipfile.ZipFile(
from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
) as zip:
zip.extractall(to_path)
_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
".tar": _extract_tar,
".zip": _extract_zip,
}
_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {
".bz2": bz2.open,
".gz": gzip.open,
".xz": lzma.open,
}
_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
".tbz": (".tar", ".bz2"),
".tbz2": (".tar", ".bz2"),
".tgz": (".tar", ".gz"),
}
def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
"""Detect the archive type and/or compression of a file.
Args:
file (str): the filename
Returns:
(tuple): tuple of suffix, archive type, and compression
Raises:
RuntimeError: if file has no suffix or suffix is not supported
"""
suffixes = pathlib.Path(file).suffixes
if not suffixes:
raise RuntimeError(
f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
)
suffix = suffixes[-1]
# check if the suffix is a known alias
if suffix in _FILE_TYPE_ALIASES:
return (suffix, *_FILE_TYPE_ALIASES[suffix])
# check if the suffix is an archive type
if suffix in _ARCHIVE_EXTRACTORS:
return suffix, suffix, None
# check if the suffix is a compression
if suffix in _COMPRESSED_FILE_OPENERS:
# check for suffix hierarchy
if len(suffixes) > 1:
suffix2 = suffixes[-2]
# check if the suffix2 is an archive type
if suffix2 in _ARCHIVE_EXTRACTORS:
return suffix2 + suffix, suffix2, suffix
return suffix, None, suffix
valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS))
raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")
def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
r"""Decompress a file.
The compression is automatically detected from the file name.
Args:
from_path (str): Path to the file to be decompressed.
to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
remove_finished (bool): If ``True``, remove the file after the extraction.
Returns:
(str): Path to the decompressed file.
"""
suffix, archive_type, compression = _detect_file_type(from_path)
if not compression:
raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")
if to_path is None:
to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")
# We don't need to check for a missing key here, since this was already done in _detect_file_type()
compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]
with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
wfh.write(rfh.read())
if remove_finished:
os.remove(from_path)
return to_path
def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
"""Extract an archive.
The archive type and a possible compression are automatically detected from the file name. If the file is compressed
but not an archive, the call is dispatched to :func:`_decompress`.
Args:
from_path (str): Path to the file to be extracted.
to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
used.
remove_finished (bool): If ``True``, remove the file after the extraction.
Returns:
(str): Path to the directory the file was extracted to.
"""
if to_path is None:
to_path = os.path.dirname(from_path)
suffix, archive_type, compression = _detect_file_type(from_path)
if not archive_type:
return _decompress(
from_path,
os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
remove_finished=remove_finished,
)
# We don't need to check for a missing key here, since this was already done in _detect_file_type()
extractor = _ARCHIVE_EXTRACTORS[archive_type]
extractor(from_path, to_path, compression)
if remove_finished:
os.remove(from_path)
return to_path
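# A hypothetical call: extract_archive("/data/foo.tar.gz") unpacks into /data
# and returns "/data", while extract_archive("/data/foo.gz") has no archive
# type, so it dispatches to _decompress and returns "/data/foo".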
def download_and_extract_archive(
url: str,
download_root: str,
extract_root: Optional[str] = None,
filename: Optional[str] = None,
md5: Optional[str] = None,
remove_finished: bool = False,
) -> None:
download_root = os.path.expanduser(download_root)
if extract_root is None:
extract_root = download_root
if not filename:
filename = os.path.basename(url)
download_url(url, download_root, filename, md5)
archive = os.path.join(download_root, filename)
print(f"Extracting {archive} to {extract_root}")
extract_archive(archive, extract_root, remove_finished)
def iterable_to_str(iterable: Iterable) -> str:
return "'" + "', '".join([str(item) for item in iterable]) + "'"
T = TypeVar("T", str, bytes)
def verify_str_arg(
value: T,
arg: Optional[str] = None,
valid_values: Optional[Iterable[T]] = None,
custom_msg: Optional[str] = None,
) -> T:
if not isinstance(value, torch._six.string_classes):
if arg is None:
msg = "Expected type str, but got type {type}."
else:
msg = "Expected type str for argument {arg}, but got type {type}."
msg = msg.format(type=type(value), arg=arg)
raise ValueError(msg)
if valid_values is None:
return value
if value not in valid_values:
if custom_msg is not None:
msg = custom_msg
else:
msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}."
msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
raise ValueError(msg)
return value
def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
"""Read file in .pfm format. Might contain either 1 or 3 channels of data.
Args:
file_name (str): Path to the file.
slice_channels (int): Number of channels to slice out of the file.
Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc.
"""
with open(file_name, "rb") as f:
header = f.readline().rstrip()
if header not in [b"PF", b"Pf"]:
raise ValueError("Invalid PFM file")
dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
if not dim_match:
raise Exception("Malformed PFM header.")
w, h = (int(dim) for dim in dim_match.groups())
scale = float(f.readline().rstrip())
if scale < 0: # little-endian
endian = "<"
scale = -scale
else:
endian = ">" # big-endian
data = np.fromfile(f, dtype=endian + "f")
pfm_channels = 3 if header == b"PF" else 1
data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1)
data = np.flip(data, axis=1) # flip on h dimension
data = data[:slice_channels, :, :]
return data.astype(np.float32)
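# A PFM header is three text lines, e.g. b"PF\n", b"640 480\n", b"-1.0\n":
# color vs. grayscale, then width and height, then the scale, whose sign
# encodes endianness (negative means little-endian). The h*w*channels float
# payload that follows is read and flipped on the height axis above, since
# PFM stores rows bottom-up. The example values here are illustrative.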

View File

@@ -1,104 +0,0 @@
import os
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.utils.data as data
from ..utils import _log_api_usage_once
class VisionDataset(data.Dataset):
"""
Base class for making datasets which are compatible with torchvision.
It is necessary to override the ``__getitem__`` and ``__len__`` methods.
Args:
root (string): Root directory of dataset.
transforms (callable, optional): A function/transform that takes in
an image and a label and returns the transformed versions of both.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
.. note::
:attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
"""
_repr_indent = 4
def __init__(
self,
root: str,
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
self.root = root
has_transforms = transforms is not None
has_separate_transform = transform is not None or target_transform is not None
if has_transforms and has_separate_transform:
raise ValueError("Only transforms or transform/target_transform can be passed as argument")
# for backwards-compatibility
self.transform = transform
self.target_transform = target_transform
if has_separate_transform:
transforms = StandardTransform(transform, target_transform)
self.transforms = transforms
def __getitem__(self, index: int) -> Any:
"""
Args:
index (int): Index
Returns:
(Any): Sample and meta data, optionally transformed by the respective transforms.
"""
raise NotImplementedError
def __len__(self) -> int:
raise NotImplementedError
def __repr__(self) -> str:
head = "Dataset " + self.__class__.__name__
body = [f"Number of datapoints: {self.__len__()}"]
if self.root is not None:
body.append(f"Root location: {self.root}")
body += self.extra_repr().splitlines()
if hasattr(self, "transforms") and self.transforms is not None:
body += [repr(self.transforms)]
lines = [head] + [" " * self._repr_indent + line for line in body]
return "\n".join(lines)
def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
lines = transform.__repr__().splitlines()
return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
def extra_repr(self) -> str:
return ""
class StandardTransform:
def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
self.transform = transform
self.target_transform = target_transform
def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
if self.transform is not None:
input = self.transform(input)
if self.target_transform is not None:
target = self.target_transform(target)
return input, target
def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
lines = transform.__repr__().splitlines()
return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
def __repr__(self) -> str:
body = [self.__class__.__name__]
if self.transform is not None:
body += self._format_transform_repr(self.transform, "Transform: ")
if self.target_transform is not None:
body += self._format_transform_repr(self.target_transform, "Target transform: ")
return "\n".join(body)

View File

@@ -1 +0,0 @@
from jittor.transform import *

View File

@@ -1,582 +0,0 @@
import collections
import math
import pathlib
import warnings
from itertools import repeat
from types import FunctionType
from typing import Any, BinaryIO, List, Optional, Tuple, Union
import numpy as np
import torch
from PIL import Image, ImageColor, ImageDraw, ImageFont
__all__ = [
"make_grid",
"save_image",
"draw_bounding_boxes",
"draw_segmentation_masks",
"draw_keypoints",
"flow_to_image",
]
@torch.no_grad()
def make_grid(
tensor: Union[torch.Tensor, List[torch.Tensor]],
nrow: int = 8,
padding: int = 2,
normalize: bool = False,
value_range: Optional[Tuple[int, int]] = None,
scale_each: bool = False,
pad_value: float = 0.0,
**kwargs,
) -> torch.Tensor:
"""
Make a grid of images.
Args:
tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
or a list of images all of the same size.
nrow (int, optional): Number of images displayed in each row of the grid.
The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
padding (int, optional): amount of padding. Default: ``2``.
normalize (bool, optional): If True, shift the image to the range (0, 1),
by the min and max values specified by ``value_range``. Default: ``False``.
value_range (tuple, optional): tuple (min, max) where min and max are numbers,
then these numbers are used to normalize the image. By default, min and max
are computed from the tensor.
range (tuple, optional):
.. warning::
This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``value_range``
instead.
scale_each (bool, optional): If ``True``, scale each image in the batch of
images separately rather than the (min, max) over all images. Default: ``False``.
pad_value (float, optional): Value for the padded pixels. Default: ``0``.
Returns:
grid (Tensor): the tensor containing grid of images.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(make_grid)
if not torch.is_tensor(tensor):
if isinstance(tensor, list):
for t in tensor:
if not torch.is_tensor(t):
raise TypeError(f"tensor or list of tensors expected, got a list containing {type(t)}")
else:
raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
if "range" in kwargs.keys():
warnings.warn(
"The parameter 'range' is deprecated since 0.12 and will be removed in 0.14. "
"Please use 'value_range' instead."
)
value_range = kwargs["range"]
# if list of tensors, convert to a 4D mini-batch Tensor
if isinstance(tensor, list):
tensor = torch.stack(tensor, dim=0)
if tensor.dim() == 2: # single image H x W
tensor = tensor.unsqueeze(0)
if tensor.dim() == 3: # single image
if tensor.size(0) == 1: # if single-channel, convert to 3-channel
tensor = torch.cat((tensor, tensor, tensor), 0)
tensor = tensor.unsqueeze(0)
if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
tensor = torch.cat((tensor, tensor, tensor), 1)
if normalize is True:
tensor = tensor.clone() # avoid modifying tensor in-place
if value_range is not None and not isinstance(value_range, tuple):
raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers")
def norm_ip(img, low, high):
img.clamp_(min=low, max=high)
img.sub_(low).div_(max(high - low, 1e-5))
def norm_range(t, value_range):
if value_range is not None:
norm_ip(t, value_range[0], value_range[1])
else:
norm_ip(t, float(t.min()), float(t.max()))
if scale_each is True:
for t in tensor: # loop over mini-batch dimension
norm_range(t, value_range)
else:
norm_range(tensor, value_range)
if not isinstance(tensor, torch.Tensor):
raise TypeError("tensor should be of type torch.Tensor")
if tensor.size(0) == 1:
return tensor.squeeze(0)
# make the mini-batch of images into a grid
nmaps = tensor.size(0)
xmaps = min(nrow, nmaps)
ymaps = int(math.ceil(float(nmaps) / xmaps))
height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
num_channels = tensor.size(1)
grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
k = 0
for y in range(ymaps):
for x in range(xmaps):
if k >= nmaps:
break
# Tensor.copy_() is a valid method but seems to be missing from the stubs
# https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined]
2, x * width + padding, width - padding
).copy_(tensor[k])
k = k + 1
return grid
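# Illustrative shape arithmetic for the tiling above: a (16, 3, 32, 32)
# batch with nrow=4 gives xmaps=4, ymaps=4 and a grid of shape
# (3, 4 * (32 + 2) + 2, 4 * (32 + 2) + 2) == (3, 138, 138).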
@torch.no_grad()
def save_image(
tensor: Union[torch.Tensor, List[torch.Tensor]],
fp: Union[str, pathlib.Path, BinaryIO],
format: Optional[str] = None,
**kwargs,
) -> None:
"""
Save a given Tensor into an image file.
Args:
tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
saves the tensor as a grid of images by calling ``make_grid``.
fp (string or file object): A filename or a file object
format (Optional): If omitted, the format to use is determined from the filename extension.
If a file object was used instead of a filename, this parameter should always be used.
**kwargs: Other arguments are documented in ``make_grid``.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(save_image)
grid = make_grid(tensor, **kwargs)
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
im = Image.fromarray(ndarr)
im.save(fp, format=format)
@torch.no_grad()
def draw_bounding_boxes(
image: torch.Tensor,
boxes: torch.Tensor,
labels: Optional[List[str]] = None,
colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
fill: Optional[bool] = False,
width: int = 1,
font: Optional[str] = None,
font_size: Optional[int] = None,
) -> torch.Tensor:
"""
Draws bounding boxes on given image.
The values of the input image should be uint8 between 0 and 255.
If fill is True, the resulting Tensor should be saved as a PNG image.
Args:
image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
`0 <= ymin < ymax < H`.
labels (List[str]): List containing the labels of bounding boxes.
colors (color or list of colors, optional): List containing the colors
of the boxes or single color for all boxes. The color can be represented as
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
By default, random colors are generated for boxes.
fill (bool): If `True` fills the bounding box with specified color.
width (int): Width of bounding box.
font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
`/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
font_size (int): The requested font size in points.
Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(draw_bounding_boxes)
if not isinstance(image, torch.Tensor):
raise TypeError(f"Tensor expected, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size(0) not in {1, 3}:
raise ValueError("Only grayscale and RGB images are supported")
elif (boxes[:, 0] > boxes[:, 2]).any() or (boxes[:, 1] > boxes[:, 3]).any():
raise ValueError(
"Boxes need to be in (xmin, ymin, xmax, ymax) format. Use torchvision.ops.box_convert to convert them"
)
num_boxes = boxes.shape[0]
if num_boxes == 0:
warnings.warn("boxes doesn't contain any box. No box was drawn")
return image
if labels is None:
labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef]
elif len(labels) != num_boxes:
raise ValueError(
f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box."
)
if colors is None:
colors = _generate_color_palette(num_boxes)
elif isinstance(colors, list):
if len(colors) < num_boxes:
raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ")
else: # colors specifies a single color for all boxes
colors = [colors] * num_boxes
colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors]
if font is None:
if font_size is not None:
warnings.warn("Argument 'font_size' will be ignored since 'font' is not set.")
txt_font = ImageFont.load_default()
else:
txt_font = ImageFont.truetype(font=font, size=font_size or 10)
# Handle Grayscale images
if image.size(0) == 1:
image = torch.tile(image, (3, 1, 1))
ndarr = image.permute(1, 2, 0).cpu().numpy()
img_to_draw = Image.fromarray(ndarr)
img_boxes = boxes.to(torch.int64).tolist()
if fill:
draw = ImageDraw.Draw(img_to_draw, "RGBA")
else:
draw = ImageDraw.Draw(img_to_draw)
for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type]
if fill:
fill_color = color + (100,)
draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
else:
draw.rectangle(bbox, width=width, outline=color)
if label is not None:
margin = width + 1
draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
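# A hypothetical call sketch: for a uint8 image of shape (3, H, W),
#   boxes = torch.tensor([[10.0, 10.0, 50.0, 50.0]])
#   out = draw_bounding_boxes(image, boxes, labels=["cat"], colors="red", width=2)
# returns a new uint8 (3, H, W) tensor with the rectangle and label drawn in.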
@torch.no_grad()
def draw_segmentation_masks(
image: torch.Tensor,
masks: torch.Tensor,
alpha: float = 0.8,
colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
) -> torch.Tensor:
"""
Draws segmentation masks on given RGB image.
The values of the input image should be uint8 between 0 and 255.
Args:
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
0 means full transparency, 1 means no transparency.
colors (color or list of colors, optional): List containing the colors
of the masks or single color for all masks. The color can be represented as
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
By default, random colors are generated for each mask.
Returns:
img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(draw_segmentation_masks)
if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size()[0] != 3:
raise ValueError("Pass an RGB image. Other Image formats are not supported")
if masks.ndim == 2:
masks = masks[None, :, :]
if masks.ndim != 3:
raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
if masks.dtype != torch.bool:
raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
if masks.shape[-2:] != image.shape[-2:]:
raise ValueError("The image and the masks must have the same height and width")
num_masks = masks.size()[0]
if colors is not None and num_masks > len(colors):
raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")
if num_masks == 0:
warnings.warn("masks doesn't contain any mask. No mask was drawn")
return image
if colors is None:
colors = _generate_color_palette(num_masks)
if not isinstance(colors, list):
colors = [colors]
if not isinstance(colors[0], (tuple, str)):
raise ValueError("colors must be a tuple or a string, or a list thereof")
if isinstance(colors[0], tuple) and len(colors[0]) != 3:
raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")
out_dtype = torch.uint8
colors_ = []
for color in colors:
if isinstance(color, str):
color = ImageColor.getrgb(color)
colors_.append(torch.tensor(color, dtype=out_dtype))
img_to_draw = image.detach().clone()
# TODO: There might be a way to vectorize this
for mask, color in zip(masks, colors_):
img_to_draw[:, mask] = color[:, None]
out = image * (1 - alpha) + img_to_draw * alpha
return out.to(out_dtype)
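# Note the blend above mixes the whole image with the overlay:
# out = (1 - alpha) * image + alpha * img_to_draw, where img_to_draw equals
# the input everywhere no mask is set, so unmasked pixels are unchanged.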
@torch.no_grad()
def draw_keypoints(
image: torch.Tensor,
keypoints: torch.Tensor,
connectivity: Optional[List[Tuple[int, int]]] = None,
colors: Optional[Union[str, Tuple[int, int, int]]] = None,
radius: int = 2,
width: int = 3,
) -> torch.Tensor:
"""
Draws Keypoints on given RGB image.
The values of the input image should be uint8 between 0 and 255.
Args:
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
keypoints (Tensor): Tensor of shape (num_instances, K, 2) giving the K keypoint locations for each of the
num_instances instances, in the format [x, y].
connectivity (List[Tuple[int, int]]): A list of tuples where each
tuple contains a pair of keypoints to be connected.
colors (str, Tuple): The color can be represented as
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
radius (int): Integer denoting radius of keypoint.
width (int): Integer denoting width of line connecting keypoints.
Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(draw_keypoints)
if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size()[0] != 3:
raise ValueError("Pass an RGB image. Other Image formats are not supported")
if keypoints.ndim != 3:
raise ValueError("keypoints must be of shape (num_instances, K, 2)")
ndarr = image.permute(1, 2, 0).cpu().numpy()
img_to_draw = Image.fromarray(ndarr)
draw = ImageDraw.Draw(img_to_draw)
img_kpts = keypoints.to(torch.int64).tolist()
for kpt_id, kpt_inst in enumerate(img_kpts):
for inst_id, kpt in enumerate(kpt_inst):
x1 = kpt[0] - radius
x2 = kpt[0] + radius
y1 = kpt[1] - radius
y2 = kpt[1] + radius
draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0)
if connectivity:
for connection in connectivity:
start_pt_x = kpt_inst[connection[0]][0]
start_pt_y = kpt_inst[connection[0]][1]
end_pt_x = kpt_inst[connection[1]][0]
end_pt_y = kpt_inst[connection[1]][1]
draw.line(
((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)),
width=width,
)
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
@torch.no_grad()
def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
"""
Converts a flow to an RGB image.
Args:
flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.
Returns:
img (Tensor): Image Tensor of dtype uint8 where each color corresponds
to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
"""
if flow.dtype != torch.float:
raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")
orig_shape = flow.shape
if flow.ndim == 3:
flow = flow[None] # Add batch dim
if flow.ndim != 4 or flow.shape[1] != 2:
raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")
max_norm = torch.sum(flow**2, dim=1).sqrt().max()
epsilon = torch.finfo(flow.dtype).eps
normalized_flow = flow / (max_norm + epsilon)
img = _normalized_flow_to_image(normalized_flow)
if len(orig_shape) == 3:
img = img[0] # Remove batch dim
return img
@torch.no_grad()
def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
"""
Converts a batch of normalized flow to an RGB image.
Args:
normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
Returns:
img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
"""
N, _, H, W = normalized_flow.shape
device = normalized_flow.device
flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
colorwheel = _make_colorwheel().to(device) # shape [55x3]
num_cols = colorwheel.shape[0]
norm = torch.sum(normalized_flow**2, dim=1).sqrt()
a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
fk = (a + 1) / 2 * (num_cols - 1)
k0 = torch.floor(fk).to(torch.long)
k1 = k0 + 1
k1[k1 == num_cols] = 0
f = fk - k0
for c in range(colorwheel.shape[1]):
tmp = colorwheel[:, c]
col0 = tmp[k0] / 255.0
col1 = tmp[k1] / 255.0
col = (1 - f) * col0 + f * col1
col = 1 - norm * (1 - col)
flow_image[:, c, :, :] = torch.floor(255 * col)
return flow_image
def _make_colorwheel() -> torch.Tensor:
"""
Generates a color wheel for optical flow visualization as presented in:
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.
Returns:
colorwheel (Tensor[55, 3]): Colorwheel Tensor.
"""
RY = 15
YG = 6
GC = 4
CB = 11
BM = 13
MR = 6
ncols = RY + YG + GC + CB + BM + MR
colorwheel = torch.zeros((ncols, 3))
col = 0
# RY
colorwheel[0:RY, 0] = 255
colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY)
col = col + RY
# YG
colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG)
colorwheel[col : col + YG, 1] = 255
col = col + YG
# GC
colorwheel[col : col + GC, 1] = 255
colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC)
col = col + GC
# CB
colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB)
colorwheel[col : col + CB, 2] = 255
col = col + CB
# BM
colorwheel[col : col + BM, 2] = 255
colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM)
col = col + BM
# MR
colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR)
colorwheel[col : col + MR, 0] = 255
return colorwheel
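# Sanity check on the segment sizes above: 15 + 6 + 4 + 11 + 13 + 6 = 55,
# matching the documented Tensor[55, 3] color wheel.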
def _generate_color_palette(num_objects: int):
palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1])
return [tuple((i * palette) % 255) for i in range(num_objects)]
def _log_api_usage_once(obj: Any) -> None:
"""
Logs API usage (module and name) within an organization.
In a large ecosystem, it's often useful to track the PyTorch and
TorchVision API usage. This API provides similar functionality to the
logging module in the Python stdlib. It can be used for debugging purposes
to log which methods are used, and by default it is inactive unless the user
manually subscribes a logger via the `SetAPIUsageLogger method <https://github.com/pytorch/pytorch/blob/eb3b9fe719b21fae13c7a7cf3253f970290a573e/c10/util/Logging.cpp#L114>`_.
Please note it is triggered only once for the same API call within a process.
It does not collect any data from open-source users since it is a no-op by default.
For more information, please refer to
* PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging;
* Logging policy: https://github.com/pytorch/vision/issues/5052;
Args:
obj (class instance or method): an object to extract info from.
"""
pass
def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
"""
Make n-tuple from input x. If x is an iterable, then we just convert it to tuple.
Otherwise we will make a tuple of length n, all with value of x.
reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8
Args:
x (Any): input value
n (int): length of the resulting tuple
"""
if isinstance(x, collections.abc.Iterable):
return tuple(x)
return tuple(repeat(x, n))
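# e.g. _make_ntuple(3, 2) == (3, 3), while _make_ntuple([1, 2], 2) == (1, 2).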

View File

@@ -3537,38 +3537,4 @@ def mish(x, inplace=False):
def skip_init(module_cls, *args, **kw):
return module_cls(*args, **kw)
'''
Only for extern/acl
'''
class FlashAttention(Module):
def __init__(self,headnum,
layout = "BSND",
prefix = None,
qstart = None,
kvstart = None,
scale = 1.0,
prob = 1.0,
pretokens = 2147483647,
nexttokens = 2147483647,
innerprecise = 0,
sparsemode = 0,
psetype = 1):
self.headnum = headnum
self.layout = layout
self.prefix = prefix
self.qstart = qstart
self.kvstart = kvstart
self.scale = scale
self.prob = prob
self.pretokens = pretokens
self.nexttokens = nexttokens
self.innerprecise = innerprecise
self.sparsemode = sparsemode
self.psetype = psetype
def execute(self,q,k,v,
realshift = None,
dropMask = None,
paddingMask = None,
attenMask = None):
pass

View File

@@ -2,6 +2,7 @@ import jittor as jt
from jittor.nn import ComplexNumber
import unittest
import numpy as np
from functools import partial
__skip_torch_test = False
try:
@@ -10,6 +11,15 @@ except:
__skip_torch_test = True
class TestResultAndGrad:
def flatten_list(self, list_like):
results = []
if isinstance(list_like, (list, tuple)):
for x in list_like:
results.extend(self.flatten_list(x))
return results
else:
return [list_like]
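# e.g. flatten_list([a, (b, [c])]) returns [a, b, c]; a bare value x comes
# back as [x], so downstream checks can always iterate over a flat list.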
def check_results(self, rlist1, rlist2):
assert len(rlist1) == len(rlist2)
for r1, r2 in zip(rlist1, rlist2):
@@ -36,13 +46,21 @@ class TestResultAndGrad:
grads.append(g.detach().cpu().numpy())
return grads
def run_jittor_op(self, op, input_list, weights=None):
def _np_to_jittor(x:np.ndarray):
if x.dtype == np.complex64 or x.dtype == np.complex128:
nx = np.stack([np.real(x), np.imag(x)], axis=-1)
return ComplexNumber(jt.array(nx, dtype=jt.float32), is_concat_value=True)
elif x.dtype == np.float32 or x.dtype == np.float64:
return jt.array(x, dtype=jt.float32)
def run_jittor_op(self, op, input_list, weights=None, key_names=None, **kwargs):
def _np_to_jittor(x):
if isinstance(x, np.ndarray):
if x.dtype == np.complex64 or x.dtype == np.complex128:
nx = np.stack([np.real(x), np.imag(x)], axis=-1)
return ComplexNumber(jt.array(nx, dtype=jt.float32), is_concat_value=True)
elif x.dtype == np.float32 or x.dtype == np.float64:
return jt.array(x, dtype=jt.float32)
else:
assert False
elif isinstance(x, (list, tuple)):
nx = [_np_to_jittor(vx) for vx in x]
if isinstance(x, tuple):
return tuple(nx)
return nx
else:
assert False
def _jittor_to_np(x):
@@ -51,11 +69,19 @@ class TestResultAndGrad:
elif isinstance(x, ComplexNumber):
return x.real.numpy() + 1j * x.imag.numpy()
assert False
ninput_list = [_np_to_jittor(x) for x in input_list]
output_list = op(*ninput_list)
if key_names is not None:
assert len(ninput_list) == len(key_names)
nkwargs = kwargs.copy()
for k, v in zip(key_names, ninput_list):
nkwargs[k] = v
output_list = op(**nkwargs)
else:
output_list = op(*ninput_list, **kwargs)
if isinstance(output_list, (jt.Var, ComplexNumber)):
output_list = [output_list]
output_list = self.flatten_list(output_list)
losses = []
if weights is None:
weights = []
@@ -73,15 +99,31 @@ class TestResultAndGrad:
output_list = [_jittor_to_np(x) for x in output_list]
return ninput_list, output_list, losses, weights
def run_torch_op(self, op, input_list, weights=None):
def _np_to_torch(x:np.ndarray):
return torch.from_numpy(x).requires_grad_(True)
def run_torch_op(self, op, input_list, weights=None, key_names=None, **kwargs):
def _np_to_torch(x):
if isinstance(x, np.ndarray):
return torch.from_numpy(x).requires_grad_(True)
elif isinstance(x, (list, tuple)):
nx = [_np_to_torch(vx) for vx in x]
if isinstance(x, tuple):
return tuple(nx)
return nx
else:
assert False
def _torch_to_np(x:torch.Tensor) -> np.ndarray:
return x.detach().cpu().numpy()
ninput_list = [_np_to_torch(x) for x in input_list]
output_list = op(*ninput_list)
if key_names is not None:
assert len(ninput_list) == len(key_names)
nkwargs = kwargs.copy()
for k, v in zip(key_names, ninput_list):
nkwargs[k] = v
output_list = op(**nkwargs)
else:
output_list = op(*ninput_list, **kwargs)
if isinstance(output_list, torch.Tensor):
output_list = [output_list]
output_list = self.flatten_list(output_list)
losses = []
if weights is None:
weights = []
@@ -99,10 +141,10 @@ class TestResultAndGrad:
output_list = [_torch_to_np(x) for x in output_list]
return ninput_list, output_list, losses, weights
def check_op_with_torch(self, jittor_op, torch_op, input_list, check_grad=True):
def check_op_with_torch(self, jittor_op, torch_op, input_list, check_grad=True, jittor_knames=None, torch_knames=None, **kwargs):
weights = None
jittor_input, jittor_output, jittor_losses, weights = self.run_jittor_op(jittor_op, input_list, weights)
torch_input, torch_output, torch_losses, weights = self.run_torch_op(torch_op, input_list, weights)
jittor_input, jittor_output, jittor_losses, weights = self.run_jittor_op(jittor_op, input_list, weights, key_names=jittor_knames, **kwargs)
torch_input, torch_output, torch_losses, weights = self.run_torch_op(torch_op, input_list, weights, key_names=torch_knames, **kwargs)
self.check_results(jittor_output, torch_output)
if check_grad:
@@ -195,6 +237,249 @@ class TestComplexLinalg(unittest.TestCase, TestResultAndGrad):
inputs = [m1]
self.check_op_with_numpy(jt.linalg.svd, np.linalg.svd, inputs)
class TestTensordot(unittest.TestCase, TestResultAndGrad):
def random_complex_matrix(self, shape):
r = np.random.randn(*shape)
i = np.random.randn(*shape)
return r + 1j * i
def random_real_matrix(self, shape):
return np.random.randn(*shape)
def test_complex_tensordot_numberdim(self):
s1 = (3, 4, 5)
s2 = (4, 5, 6)
dims = 2
m1 = self.random_complex_matrix(s1)
m2 = self.random_complex_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims)
def test_complex_tensordot_tupledim(self):
s1 = (3, 5, 4, 6)
s2 = (6, 4, 5, 3)
dims = ([2, 1, 3], [1, 2, 0])
m1 = self.random_complex_matrix(s1)
m2 = self.random_complex_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims)
def test_real_tensordot_numberdim(self):
s1 = (3, 4, 5)
s2 = (4, 5, 6)
dims = 2
m1 = self.random_real_matrix(s1)
m2 = self.random_real_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims)
def test_real_tensordot_tupledim(self):
s1 = (3, 5, 4, 6)
s2 = (6, 4, 5, 3)
dims = ([2, 1, 3], [1, 2, 0])
m1 = self.random_real_matrix(s1)
m2 = self.random_real_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims)
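# For the tuple-dims cases above, axes (2, 1, 3) of m1 are contracted
# against axes (1, 2, 0) of m2, pairing sizes 4/4, 5/5 and 6/6; the free
# axes contribute (3,) from m1 and (3,) from m2, giving a (3, 3) result.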
class TestKron(unittest.TestCase, TestResultAndGrad):
def random_complex_matrix(self, shape):
r = np.random.randn(*shape)
i = np.random.randn(*shape)
return r + 1j * i
def random_real_matrix(self, shape):
return np.random.randn(*shape)
def test_complex_firstlarge(self):
s1 = (2, 3, 4)
s2 = (5, 2)
m1 = self.random_complex_matrix(s1)
m2 = self.random_complex_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)
def test_complex_second_large(self):
s1 = (2, 3)
s2 = (5, 2, 4)
m1 = self.random_complex_matrix(s1)
m2 = self.random_complex_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)
def test_real_firstlarge(self):
s1 = (2, 3, 4)
s2 = (5, 2)
m1 = self.random_real_matrix(s1)
m2 = self.random_real_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)
def test_real_second_large(self):
s1 = (2, 3)
s2 = (5, 2, 4)
m1 = self.random_real_matrix(s1)
m2 = self.random_real_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)
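# torch.kron right-aligns the operand shapes, so for (2, 3) kron (5, 2, 4)
# the first factor is treated as (1, 2, 3) and the result has shape
# (5, 4, 12), the elementwise products of the padded dimensions.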
@unittest.skipIf(__skip_torch_test, "No Torch found")
class TestGradFunctional(unittest.TestCase, TestResultAndGrad):
def random_complex_matrix(self, shape):
r = np.random.randn(*shape)
i = np.random.randn(*shape)
return r + 1j * i
def random_real_matrix(self, shape):
return np.random.randn(*shape) * 0.0 + 1.0
def test_real_jvp_exp(self):
def exp_reducer(x):
return x.exp().sum(dim=1)
s1 = (5, 6)
m1 = self.random_real_matrix(s1)
m2 = self.random_real_matrix(s1)
inputs = [m1, m2]
self.check_op_with_torch(
partial(jt.gradfunctional.jvp, func=exp_reducer, create_graph=True),
partial(torch.autograd.functional.jvp, func=exp_reducer, create_graph=True),
inputs,
jittor_knames = ['inputs', 'v'],
torch_knames = ['inputs', 'v'],
check_grad=False)
def test_complex_jvp_exp(self):
def exp_reducer(x):
return x.exp().sum(1)
s1 = (5, 6)
m1 = self.random_complex_matrix(s1)
m2 = self.random_complex_matrix(s1)
inputs = [m1, m2]
self.check_op_with_torch(
partial(jt.gradfunctional.jvp, func=exp_reducer, create_graph=True),
partial(torch.autograd.functional.jvp, func=exp_reducer, create_graph=True),
inputs,
jittor_knames = ['inputs', 'v'],
torch_knames = ['inputs', 'v'],
check_grad=False,
)
def test_real_jvp_add(self):
w1, w2 = np.random.rand(), np.random.rand()
def adder(x, y):
return w1 * x + w2 * y
s1 = (5, 6)
m1 = self.random_real_matrix(s1)
m2 = self.random_real_matrix(s1)
m3 = self.random_real_matrix(s1)
m4 = self.random_real_matrix(s1)
inputs = [(m1, m2), (m3, m4)]
self.check_op_with_torch(
partial(jt.gradfunctional.jvp, func=adder, create_graph=True),
partial(torch.autograd.functional.jvp, func=adder, create_graph=True),
inputs,
jittor_knames = ['inputs', 'v'],
torch_knames = ['inputs', 'v'],
check_grad=False,
)
def test_complex_jvp_add(self):
w1r, w1i = np.random.rand(), np.random.rand()
w2r, w2i = np.random.rand(), np.random.rand()
def adder_pt(x, y):
return (w1r + 1j * w1i) * x + (w2r + 1j * w2i) * y
def adder_jt(x, y):
w1 = ComplexNumber(real=jt.array(w1r).reshape(1,1), imag = jt.array(w1i).reshape(1,1))
w2 = ComplexNumber(real=jt.array(w2r).reshape(1,1), imag = jt.array(w2i).reshape(1,1))
return w1 * x + w2 * y
s1 = (5, 6)
m1 = self.random_complex_matrix(s1)
m2 = self.random_complex_matrix(s1)
m3 = self.random_complex_matrix(s1)
m4 = self.random_complex_matrix(s1)
inputs = [(m1, m2), (m3, m4)]
self.check_op_with_torch(
partial(jt.gradfunctional.jvp, func=adder_jt, create_graph=True),
partial(torch.autograd.functional.jvp, func=adder_pt, create_graph=True),
inputs,
jittor_knames = ['inputs', 'v'],
torch_knames = ['inputs', 'v'],
check_grad=False,
)
def test_real_vjp_exp(self):
def exp_reducer(x):
return x.exp().sum(dim=1)
s1 = (5, 6)
s2 = (5,)
m1 = self.random_real_matrix(s1)
m2 = self.random_real_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(
partial(jt.gradfunctional.vjp, func=exp_reducer),
partial(torch.autograd.functional.vjp, func=exp_reducer),
inputs,
jittor_knames = ['inputs', 'v'],
torch_knames = ['inputs', 'v'],
check_grad=False)
def test_complex_vjp_exp(self):
def exp_reducer(x):
return x.exp().sum(1)
s1 = (5, 6)
s2 = (5,)
m1 = self.random_real_matrix(s1)
m2 = self.random_real_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(
partial(jt.gradfunctional.vjp, func=exp_reducer, create_graph=True),
partial(torch.autograd.functional.vjp, func=exp_reducer, create_graph=True),
inputs,
jittor_knames = ['inputs', 'v'],
torch_knames = ['inputs', 'v'],
check_grad=False,
)
def test_real_vjp_add(self):
w1, w2 = np.random.rand(), np.random.rand()
def adder(x, y):
return w1 * x + w2 * y
s1 = (5, 6)
m1 = self.random_real_matrix(s1)
m2 = self.random_real_matrix(s1)
m3 = self.random_real_matrix(s1)
inputs = [(m1, m2), m3]
self.check_op_with_torch(
partial(jt.gradfunctional.vjp, func=adder, create_graph=True),
partial(torch.autograd.functional.vjp, func=adder, create_graph=True),
inputs,
jittor_knames = ['inputs', 'v'],
torch_knames = ['inputs', 'v'],
check_grad=False,
)
def test_complex_vjp_add(self):
w1r, w1i = np.random.rand(), np.random.rand()
w2r, w2i = np.random.rand(), np.random.rand()
def adder_pt(x, y):
return (w1r + 1j * w1i) * x + (w2r + 1j * w2i) * y
def adder_jt(x, y):
w1 = ComplexNumber(real=jt.array(w1r).reshape(1,1), imag = jt.array(w1i).reshape(1,1))
w2 = ComplexNumber(real=jt.array(w2r).reshape(1,1), imag = jt.array(w2i).reshape(1,1))
return w1 * x + w2 * y
s1 = (5, 6)
m1 = self.random_complex_matrix(s1)
m2 = self.random_complex_matrix(s1)
m3 = self.random_complex_matrix(s1)
inputs = [(m1, m2), m3]
self.check_op_with_torch(
partial(jt.gradfunctional.vjp, func=adder_jt, create_graph=True),
partial(torch.autograd.functional.vjp, func=adder_pt, create_graph=True),
inputs,
jittor_knames = ['inputs', 'v'],
torch_knames = ['inputs', 'v'],
check_grad=False,
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,112 @@
# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers: Shizhan Lu <578752274@qq.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
from jittor.compile_extern import cusparse_ops
class TestSpmmCsrOp(unittest.TestCase):
@unittest.skipIf(not jt.has_cuda, "No CUDA support, skipping test")
@jt.flag_scope(use_cuda=1, lazy_execution=0)
def test_spmm_csr_forward_float32_int32(self):
x = jt.array([[3.0, 2.0, 1.0], [3.0, 2.0, 1.0], [3.0, 2.0, 1.0]], dtype="float32")
col_indices = jt.array([0, 1, 1, 2], dtype="int32")
row_offset = jt.array([0, 2, 3, 4], dtype="int32")
csr_weight = jt.array([3.0, 1.0, 4.0, 2.0], dtype="float32")
output = jt.zeros((3, 3), dtype="float32")
cusparse_ops.cusparse_spmmcsr(
output, x, col_indices, csr_weight, row_offset,
3, 3, False, False
).fetch_sync()
expected_output = np.array([
[12.0, 8.0, 4.0],
[12.0, 8.0, 4.0],
[6.0, 4.0, 2.0]
])
np.testing.assert_allclose(output.data, expected_output, atol=1e-5)
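# The CSR operands above encode the dense matrix
#   A = [[3, 1, 0],
#        [0, 4, 0],
#        [0, 0, 2]]
# (row_offset[i]:row_offset[i+1] selects row i's col_indices/csr_weight
# entries), so expected_output is simply A @ x computed by hand.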
@unittest.skipIf(not jt.has_cuda, "No CUDA support, skipping test")
@jt.flag_scope(use_cuda=1, lazy_execution=0)
def test_spmm_csr_forward_float16_int32(self):
x = jt.array([[3.0, 2.0, 1.0], [3.0, 2.0, 1.0], [3.0, 2.0, 1.0]], dtype="float16")
col_indices = jt.array([0, 1, 1, 2], dtype="int32")
row_offset = jt.array([0, 2, 3, 4], dtype="int32")
csr_weight = jt.array([3.0, 1.0, 4.0, 2.0], dtype="float16")
output = jt.zeros((3, 3), dtype="float16")
cusparse_ops.cusparse_spmmcsr(
output, x, col_indices, csr_weight, row_offset,
3, 3, False, False
).fetch_sync()
expected_output = np.array([
[12.0, 8.0, 4.0],
[12.0, 8.0, 4.0],
[6.0, 4.0, 2.0]
], dtype="float16")
np.testing.assert_allclose(output.data, expected_output, atol=1e-5)
# @unittest.skipIf(not jt.has_cuda, "No CUDA support, skipping test")
# @jt.flag_scope(use_cuda=1, lazy_execution=0)
# def test_spmm_csr_forward_float64_int32(self):
# x = jt.array([[3.0, 2.0, 1.0], [3.0, 2.0, 1.0], [3.0, 2.0, 1.0]], dtype="float64")
# col_indices = jt.array([0, 1, 1, 2], dtype="int32")
# row_offset = jt.array([0, 2, 3, 4], dtype="int32")
# csr_weight = jt.array([3.0, 1.0, 4.0, 2.0], dtype="float64")
# output = jt.zeros((3, 3), dtype="float64")
# cusparse_ops.cusparse_spmmcsr(
# output, x, col_indices, csr_weight, row_offset,
# 3, 3,False, False
# ).fetch_sync()
# expected_output = np.array([
# [12.0, 8.0, 4.0],
# [12.0, 8.0, 4.0],
# [6.0, 4.0, 2.0]
# ], dtype="float64")
# np.testing.assert_allclose(output.data, expected_output, atol=1e-5)
@unittest.skipIf(not jt.has_cuda, "No CUDA support, skipping test")
@jt.flag_scope(use_cuda=1, lazy_execution=0)
def test_spmm_csr_forward_float32_int64(self):
x = jt.array([[3.0, 2.0, 1.0], [3.0, 2.0, 1.0], [3.0, 2.0, 1.0]], dtype="float32")
col_indices = jt.array([0, 1, 1, 2], dtype="int64")
row_offset = jt.array([0, 2, 3, 4], dtype="int64")
csr_weight = jt.array([3.0, 1.0, 4.0, 2.0], dtype="float32")
output = jt.zeros((3, 3), dtype="float32")
cusparse_ops.cusparse_spmmcsr(
output, x, col_indices, csr_weight, row_offset,
3, 3, False, False
).fetch_sync()
expected_output = np.array([
[12.0, 8.0, 4.0],
[12.0, 8.0, 4.0],
[6.0, 4.0, 2.0]
], dtype="float32")
np.testing.assert_allclose(output.data, expected_output, atol=1e-5)
@unittest.skipIf(not jt.has_cuda, "No CUDA support, skipping test")
@jt.flag_scope(use_cuda=1, lazy_execution=0)
def test_spmm_coo(self):
x = jt.array([[3.0, 2.0, 1.0], [4.0, 2.0, 2.0], [1.0, 2.0, 3.0]], dtype="float32")
edge_index = jt.array([[0, 0, 1, 2], [1, 2, 2, 1]], dtype="int32")
row_indices = edge_index[0, :]
col_indices = edge_index[1, :]
edge_weight = jt.array([1.0, 1.0, 1.0, 1.0], dtype="float32")
feature_dim = jt.size(x, 1)
output = jt.zeros(3, feature_dim)
cusparse_ops.cusparse_spmmcoo(output, x, row_indices, col_indices, edge_weight, 3, 3, False, False).fetch_sync()
print("Output:", output)
expected_output = np.array([
[5.0, 4.0, 5.0],
[1.0, 2.0, 3.0],
[4.0, 2.0, 2.0]
], dtype="float32")
np.testing.assert_allclose(output.data, expected_output, atol=1e-5)
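# In COO terms, edge_index lists the entries (0,1), (0,2), (1,2), (2,1),
# each with weight 1, so output row i sums the rows x[j] with an (i, j)
# entry, e.g. row 0 = x[1] + x[2] = [5.0, 4.0, 5.0].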
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,63 @@
import unittest
import jittor as jt
import numpy as np
class TestFinfo(unittest.TestCase):
def test(self):
for dtype in ['float16', 'float32', 'float64']:
finfo = jt.finfo(dtype)
np_finfo = np.finfo(dtype)
assert finfo.bits == np_finfo.bits
assert finfo.eps == np_finfo.eps
assert finfo.max == np_finfo.max
assert finfo.min == np_finfo.min
assert finfo.nexp == np_finfo.nexp
assert finfo.nmant == np_finfo.nmant
assert finfo.iexp == np_finfo.iexp
assert finfo.precision == np_finfo.precision
assert finfo.resolution == np_finfo.resolution
assert finfo.tiny == np_finfo.tiny
for dtype_jt, dtype in [
(jt.float16, 'float16'),
(jt.float32, 'float32'),
(jt.float64, 'float64'),
]:
finfo = jt.finfo(dtype_jt)
np_finfo = np.finfo(dtype)
assert finfo.bits == np_finfo.bits
assert finfo.eps == np_finfo.eps
assert finfo.max == np_finfo.max
assert finfo.min == np_finfo.min
assert finfo.nexp == np_finfo.nexp
assert finfo.nmant == np_finfo.nmant
assert finfo.iexp == np_finfo.iexp
assert finfo.precision == np_finfo.precision
assert finfo.resolution == np_finfo.resolution
assert finfo.tiny == np_finfo.tiny
class TestIinfo(unittest.TestCase):
def test(self):
for dtype in ['int16', 'int32', 'int64']:
iinfo = jt.iinfo(dtype)
np_iinfo = np.iinfo(dtype)
assert iinfo.bits == np_iinfo.bits
assert iinfo.max == np_iinfo.max
assert iinfo.min == np_iinfo.min
assert iinfo.dtype == np.dtype(dtype)
for dtype_jt, dtype in [
(jt.int16, 'int16'),
(jt.int32, 'int32'),
(jt.int64, 'int64'),
]:
iinfo = jt.iinfo(dtype_jt)
np_iinfo = np.iinfo(dtype)
assert iinfo.bits == np_iinfo.bits
assert iinfo.max == np_iinfo.max
assert iinfo.min == np_iinfo.min
assert iinfo.dtype == np.dtype(dtype)
if __name__ == "__main__":
unittest.main()

View File

@@ -46,6 +46,33 @@ class TestVarFunctions(unittest.TestCase):
np.testing.assert_allclose(jt_x.norm(2,1).numpy(), tc_x.norm(2,1).numpy(), atol=1e-6)
np.testing.assert_allclose(jt_x.norm(2,0).numpy(), tc_x.norm(2,0).numpy(), atol=1e-6)
def test_std_with_dim(self):
x=np.random.randn(100, 1000).astype(np.float32)
jt_x = jt.array(x)
tc_x = torch.from_numpy(x)
np.testing.assert_allclose(jt_x.std(dim=-1).numpy(), tc_x.std(dim=-1).numpy(), 1e-4)
np.testing.assert_allclose(jt_x.std(dim=0, keepdim=True).numpy(), tc_x.std(dim=0, keepdim=True).numpy(), 1e-4)
def test_diagonal(self):
x = np.reshape(np.arange(5*6*7*8), (5,6,7,8))
jt_x = jt.array(x)
tc_x = torch.from_numpy(x)
def __assert_equal(a:np.ndarray, b:np.ndarray, rtol=1e-6, atol=1e-6):
assert a.shape == b.shape, f"{a.shape}!={b.shape}"
np.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
__assert_equal(jt.misc.diagonal(jt_x, 0, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=0, dim1=1, dim2=2).numpy())
__assert_equal(jt.misc.diagonal(jt_x, -1, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=-1, dim1=1, dim2=2).numpy())
__assert_equal(jt.misc.diagonal(jt_x, -2, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=-2, dim1=1, dim2=2).numpy())
__assert_equal(jt.misc.diagonal(jt_x, -6, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=-6, dim1=1, dim2=2).numpy())
__assert_equal(jt.misc.diagonal(jt_x, 1, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=1, dim1=1, dim2=2).numpy())
__assert_equal(jt.misc.diagonal(jt_x, 2, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=2, dim1=1, dim2=2).numpy())
__assert_equal(jt.misc.diagonal(jt_x, 7, dim1=1, dim2=2).numpy(), tc_x.diagonal(offset=7, dim1=1, dim2=2).numpy())
__assert_equal(jt.misc.diagonal(jt_x, 1, dim1=-1, dim2=2).numpy(), tc_x.diagonal(offset=1, dim1=-1, dim2=2).numpy())
__assert_equal(jt.misc.diagonal(jt_x, 1, dim1=-1, dim2=-2).numpy(), tc_x.diagonal(offset=1, dim1=-1, dim2=-2).numpy())
__assert_equal(jt.misc.diagonal(jt_x, 1, dim1=-2, dim2=-1).numpy(), tc_x.diagonal(offset=1, dim1=-2, dim2=-1).numpy())
__assert_equal(jt.misc.diagonal(jt_x, 1, dim1=0, dim2=-2).numpy(), tc_x.diagonal(offset=1, dim1=0, dim2=-2).numpy())
__assert_equal(jt.misc.diagonal(jt_x, 1, dim1=2, dim2=1).numpy(), tc_x.diagonal(offset=1, dim1=2, dim2=1).numpy())
if __name__ == "__main__":
unittest.main()