mirror of https://github.com/Jittor/Jittor
commit f8e44de79d (parent b79ac22b05): merge by JittorHW
@@ -0,0 +1,430 @@
# import os
# os.environ["FIX_TORCH_ERROR"] = "0"

# import jittor as jt
# from jittor import *
# from typing import Tuple

# org_int = int = type(1)
# org_float = float = type(1.0)
# org_bool = bool = type(True)

# import jtorch.compiler

# import jtorch_core
# from jtorch_core import *

# device.__reduce__ = lambda self: (device, (self.type,))
# device.__module__ = "jtorch"
# jt.jittor_core.device = device

# def handle_dtype(args, kw, dtype):
#     def convert(x):
#         if isinstance(x, jt.Var):
#             return x.cast(dtype)
#         return x
#     if dtype is not None:
#         if args is not None:
#             if isinstance(args, (tuple, list)):
#                 args = [convert(a) for a in args]
#             else:
#                 args = convert(args)
#         if kw is not None:
#             kw = {k: convert(v) for k, v in kw.items()}
#     return args, kw

# def get_args_names(func):
#     import inspect
#     spec = inspect.getfullargspec(func)
#     return spec[0] + spec[4]

# def wrapper(func):
#     has_dtype = False
#     if hasattr(func, "__code__"):
#         has_dtype = "dtype" in get_args_names(func)
#     def inner(*args, **kw):
#         requires_grad = None
#         dtype = None
#         if "requires_grad" in kw:
#             requires_grad = kw["requires_grad"]
#             del kw["requires_grad"]
#         if not has_dtype and "dtype" in kw:
#             dtype = kw["dtype"]
#             del kw["dtype"]
#         if "device" in kw:
#             del kw["device"]
#         if 'pin_memory' in kw:
#             del kw['pin_memory']
#         args, kw = handle_dtype(args, kw, dtype)
#         ret = func(*args, **kw)
#         if isinstance(ret, jt.Var):
#             if requires_grad is not None:
#                 ret.requires_grad = requires_grad
#             if dtype is not None:
#                 ret = ret.astype(dtype)
#         return ret
#     return inner


# import inspect
# _wrapper_keys = set(["shape", "start", "size"])
# _wrapper_keys.add("x")
# for k, v in list(globals().items()):
#     if callable(v) and not isinstance(v, type):
#         try:
#             spec = inspect.getfullargspec(v)
#             args_name = spec[0]
#             if len(args_name) and args_name[0] in _wrapper_keys:
#                 globals()[k] = wrapper(v)
#             elif spec.varargs in _wrapper_keys:
#                 globals()[k] = wrapper(v)
#         except:
#             pass

# def empty(*size, dtype=jt.float32, device=None, requires_grad=False):
#     if len(size) == 1 and not isinstance(size[0], org_int):
#         size = size[0]
#     return jt.empty(size, dtype)

# Tensor = Var

# Tensor.backward = lambda x: jtorch_core.backward(x)
# Tensor.grad = property(grad_get, grad_set, grad_del)
# Tensor.retains_grad = property(retain_grad_get, retain_grad_set)
# def retain_grad(x: Tensor, value: bool = True):
#     x.retains_grad = value
#     return value
# Tensor.retain_grad = retain_grad

# Tensor.dim = lambda self: self.ndim
# Tensor.ndimension = lambda self: self.ndim
# Tensor.nelement = lambda self: self.numel()
# Tensor.cuda = lambda self: self
# def device_get(x: Tensor):
#     return device("cpu") if not jt.has_cuda or not jt.flags.use_cuda else device("cuda")
# Tensor.device = property(device_get)

# def argmax(x: Var, dim=None, keepdim: bool = False):
#     return jt.argmax(x, dim, keepdim)[0]
# Tensor.argmax = argmax

# def tensor_type(x: Var, dtype=None, **kwargs):
#     if dtype:
#         return x.astype(dtype)
#     else:
#         return x.dtype
# Tensor.type = tensor_type

# def is_floating_point(x: Var):
#     return "float" in str(x.dtype)
# Tensor.is_floating_point = is_floating_point

# from . import autograd
# from .autograd import *

# def tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False):
#     if isinstance(data, list):
#         data_list = []
#         check = True
#         for p in data:
#             if isinstance(p, Tensor) and p.numel() == 1:
#                 data_list.append(p.item())
#             elif isinstance(p, (org_int, org_float)):
#                 data_list.append(p)
#             else:
#                 check = False
#                 break
#         if check:
#             data = data_list
#     return wrapper(array)(data, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory)

# # tensor = wrapper(array)
# from_numpy = wrapper(array)
# strided = None

# def mod_zero_grad(self):
#     for p in self.parameters():
#         p.grad = None
# Module.zero_grad = mod_zero_grad

# class ModuleMisc:
#     def parameters(self):
#         return iter(super().parameters())

#     def load_state_dict(self, state_dict, strict=False):
#         return super().load_state_dict(state_dict)

#     def to(self, device=None, dtype=None):
#         ''' do nothing but return itself '''
#         return self
#     def register_parameter(self, name, data):
#         setattr(self, name, data)

#     def buffers(self):
#         for _, buf in self.named_buffers():
#             yield buf


# def make_module(cls):
#     class TMod(ModuleMisc, cls):
#         def __init__(self, *args, **kw):
#             dtype = None
#             if "dtype" in kw:
#                 dtype = kw["dtype"]
#                 del kw["dtype"]
#             self._dtype = dtype
#             with jt.flag_scope(th_mode=0):
#                 if "device" in kw:
#                     del kw["device"]
#                 super().__init__(*args, **kw)
#             for k, v in self.__dict__.items():
#                 if not k.startswith("_") and isinstance(v, Var) \
#                         and v.requires_grad:
#                     v.retain_grad()
#                 if dtype is not None and isinstance(v, Var):
#                     v.assign(v.cast(dtype))
#         def __call__(self, *args, **kw):
#             args, kw = handle_dtype(args, kw, self._dtype)
#             # if forward is overridden by the user, call forward
#             if self.__class__.forward is not TMod.forward:
#                 return self.forward(*args, **kw)
#             return self.execute(*args, **kw)
#         def forward(self, *args, **kw):
#             args, kw = handle_dtype(args, kw, self._dtype)
#             return self.execute(*args, **kw)

#         @property
#         def training(self):
#             if not hasattr(self, "is_train"):
#                 self.is_train = True
#             return self.is_train
#         @training.setter
#         def training(self, value):
#             self.is_train = value

#     TMod.__name__ = cls.__name__
#     return TMod

# import jtorch.cuda
# import jtorch.nn
# from jtorch.nn import Module, Parameter
# import jtorch.optim

# from jtorch.utils.dtype import Dtype, get_string_dtype

# def frombuffer(buffer: bytearray,
#                *,
#                dtype: Dtype,
#                count: int = -1,
#                offset: int = 0,
#                requires_grad: bool = True) -> Tensor:
#     dtype = get_string_dtype(dtype)
#     tensor = jt.array(np.frombuffer(buffer, dtype, count=count, offset=offset))
#     if requires_grad and tensor.dtype.is_float():
#         tensor.requires_grad = True
#     return tensor

# def conflict_wrapper(origin_func, new_func):
#     def wrapper(*args, **kw):
#         if jt.flags.th_mode:
#             return new_func(*args, **kw)
#         else:
#             return origin_func(*args, **kw)
#     return wrapper

# def min(*args, **kw):
#     dim = None
#     if len(args) >= 2 and isinstance(args[1], org_int):
#         dim = args[1]
#     elif "dim" in kw and isinstance(kw["dim"], org_int):
#         dim = kw["dim"]
#     if dim is not None:
#         k, v = jt.argmin(*args, **kw)
#         return v, k
#     elif len(args) == 2 and isinstance(args[1], jt.Var):
#         return jt.minimum(args[0], args[1])
#     else:
#         return jt.min(*args, **kw)
# Tensor.min = conflict_wrapper(jt.min, min)

# def max(*args, **kw):
#     dim = None
#     if len(args) >= 2 and isinstance(args[1], org_int):
#         dim = args[1]
#     elif "dim" in kw and isinstance(kw["dim"], org_int):
#         dim = kw["dim"]
#     if dim is not None:
#         k, v = jt.argmax(*args, **kw)
#         return v, k
#     elif len(args) == 2 and isinstance(args[1], jt.Var):
#         return jt.maximum(args[0], args[1])
#     else:
#         return jt.max(*args, **kw)
# Tensor.max = conflict_wrapper(jt.max, max)

# def argsort(*args, **kw):
#     k, v = jt.argsort(*args, **kw)
#     return k
# Tensor.argsort = conflict_wrapper(jt.argsort, argsort)

# LongTensor = jt.int64
# FloatTensor = jt.float
# HalfTensor = jt.float16
# BoolTensor = jt.bool
# IntTensor = jt.int32

# class JDType:
#     def __init__(self, func, str):
#         self.func = func
#         self.str = str
#         self.__name__ = str.split(".")[-1]
#     def __call__(self, *args, **kw):
#         return self.func(*args, **kw)
#     def __str__(self):
#         return self.str
#     def is_floating_point(self):
#         return "float" in str(self.str)

# int8 = JDType(jt.int8, "torch.int8")
# int16 = JDType(jt.int16, "torch.int16")
# int = int32 = JDType(jt.int32, "torch.int32")
# long = int64 = JDType(jt.int64, "torch.int64")

# half = float16 = JDType(jt.float16, "torch.float16")
# float = float32 = JDType(jt.float32, "torch.float32")
# double = float64 = JDType(jt.float64, "torch.float64")
# bfloat16 = "bfloat16"  # TODO
# complex64 = "complex64"  # TODO
# complex128 = "complex128"  # TODO
# def get_JDtype(dtype):
#     if dtype == 'float32' or dtype == jt.float32:
#         return float32
#     elif dtype == 'float64' or dtype == jt.float64:
#         return float64
#     elif dtype == 'float16' or dtype == jt.float16:
#         return float16
#     elif dtype == 'int32' or dtype == jt.int32:
#         return int32
#     elif dtype == 'int64' or dtype == jt.int64:
#         return int64
#     elif dtype == 'int16' or dtype == jt.int16:
#         return int16
#     elif dtype == 'int8' or dtype == jt.int8:
#         return int8
#     else:
#         raise Exception("dtype {} not supported".format(dtype))

# def load(path, **kwargs):
#     def _to_jittor(data):
#         if isinstance(data, dict):
#             return {k: _to_jittor(d) for k, d in data.items()}
#         if isinstance(data, list):
#             return [_to_jittor(d) for d in data]
#         if isinstance(data, np.ndarray):
#             return jt.array(data)
#         return data
#     data = jt.load(path)

#     return _to_jittor(data)

# def is_tensor(x):
#     return isinstance(x, Tensor)

# manual_seed = jt.set_global_seed
# jt.flags.amp_level = 3
# Size = jt.NanoVector

# class Generator:
#     def __init__(self, *args, **kw) -> None:
#         self.seed = None
#     def manual_seed(self, seed):
#         self.seed = seed



# from . import fx


# _default_type = "float32"

# def get_default_dtype():
#     return _default_type
# def set_default_dtype(dtype):
#     global _default_type
#     _default_type = dtype

# dtype = JDType

# def div(x, y, rounding_mode="floor"):
#     assert rounding_mode == "floor"
#     z = (x / y)
#     if rounding_mode == "floor":
#         z = z.floor()
#     if x.dtype == "int32" and (isinstance(y, org_int) or y.dtype == "int32"):
#         z = z.int32()
#     return z


# def randn(*args, **kw):
#     wrap_randn = wrapper(jt.randn)
#     generator = kw.get('generator', None)
#     kw.pop('generator', None)
#     if 'layout' in kw:
#         del kw['layout']
#     if generator is not None and generator.seed is not None:
#         jt.set_global_seed(generator.seed)
#     return wrap_randn(*args, **kw)

# def rand(*args, **kw):
#     wrap_rand = wrapper(jt.rand)
#     generator = kw.get('generator', None)
#     kw.pop('generator', None)
#     if 'layout' in kw:
#         del kw['layout']
#     if generator is not None and generator.seed is not None:
#         jt.set_global_seed(generator.seed)
#     return wrap_rand(*args, **kw)



# def set_default_tensor_type(t):  # t: a tensor type, or a "torch.cuda.FloatTensor"-style string
#     if isinstance(t, str):
#         info = t.split(".")
#         if len(info) == 3 and info[1] == 'cuda':
#             jt.flags.use_cuda = 1
#     # TODO: type


# def clamp(x, min=None, max=None):
#     return jt.clamp(x, min, max)


# def to(x, *args, **kw):
#     device = None
#     if len(args) == 1:
#         device = args[0]
#         if isinstance(device, jt.NanoString) or callable(device):
#             return jt.to(x, *args, **kw)
#         if 'cpu' in str(device):
#             args = []
#     device = kw.get("device", None)
#     if 'cpu' in str(device):
#         kw.pop('device', None)
#     return jt.to(x, *args, **kw)
# Tensor.to = conflict_wrapper(jt.to, to)

# mm = wrapper(jt.matmul)

# def _data_get(x):
#     return x

# def _data_set(x, value):
#     x.assign(value)

# Tensor.data = property(_data_get, _data_set)
# Tensor.layout = None
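# A minimal illustration of what `wrapper` above does once this module is
# enabled (hypothetical call; `empty` and the kwarg names mirror the torch API):
# it strips torch-only kwargs (`device`, `pin_memory`), casts Var arguments via
# `handle_dtype`, and applies `requires_grad`/`dtype` to a Var result:
#
#     x = empty(2, 3, device="cuda", dtype=float32, requires_grad=True)
#     # -> jt.empty((2, 3)) cast to float32, with x.requires_grad == True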
@@ -0,0 +1,134 @@
import jittor as jt
from jittor import Var
from collections.abc import Sequence, Mapping

Variable = Var

class FunctionContext:
    def save_for_backward(self, *args):
        self.saved_tensors = args

class Function:
    ''' Function Module for customized backward operations

    Example 1 (a Function can take multiple inputs, produce multiple outputs,
    and store values for the backward computation)::

        import jtorch
        from jtorch import Function

        class MyFunc(Function):
            @staticmethod
            def forward(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y

            @staticmethod
            def backward(self, grad0, grad1):
                return grad0 * self.y, grad1 * self.x

        a = jtorch.array(3.0)
        a.requires_grad = True
        b = jtorch.array(4.0)
        b.requires_grad = True
        func = MyFunc.apply
        c,d = func(a, b)
        (c+d*3).backward()
        assert a.grad.data == 4
        assert b.grad.data == 9

    Example 2 (a Function can return None for "no gradient", and incoming
    gradients can also be None)::

        import jtorch
        from jtorch import Function

        class MyFunc(Function):
            @staticmethod
            def forward(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y

            @staticmethod
            def backward(self, grad0, grad1):
                assert grad1 is None
                return grad0 * self.y, None

        a = jt.array(3.0)
        a.requires_grad = True
        b = jt.array(4.0)
        b.requires_grad = True
        func = MyFunc.apply
        c,d = func(a, b)
        d.stop_grad()
        da, db = jt.grad(c+d*3, [a, b])
        assert da.data == 4
        assert db.data == 0

    '''
    def __call__(self, *args):
        backup = args
        args = list(args)
        taped_inputs = []
        taped_outputs = []
        input_mask = [-1] * len(args)
        for i, v in enumerate(args):
            if isinstance(v, Var):
                if v.is_stop_grad():
                    # -2 in input_mask represents it is stop_grad
                    input_mask[i] = -2
                    continue
                v = v.tape()
                input_mask[i] = len(taped_inputs)
                args[i] = v
                taped_inputs.append(v)
        ctx = FunctionContext()
        ori_res = self.forward(ctx, *args)
        # ori_res = self.execute(*args)
        if not isinstance(ori_res, Sequence):
            res = [ori_res]
        else:
            res = list(ori_res)
        output_mask = [-1] * len(res)
        for i, v in enumerate(res):
            if isinstance(v, Var):
                v = v.tape()
                output_mask[i] = len(taped_outputs)
                res[i] = v
                taped_outputs.append(v)
        ctx.input_mask = input_mask
        ctx.output_mask = output_mask
        # tape output and input together so
        # backward treats them as one operator
        jt.tape_together(taped_inputs, taped_outputs,
            lambda *args: self._grad(ctx, self, *args))
        if isinstance(ori_res, Sequence):
            return res
        else:
            return res[0]

    @staticmethod
    def _grad(ctx, func, *args):
        new_args = ( (args[i] if i>=0 else None) for i in ctx.output_mask )
        ret = func.backward(ctx, *new_args)
        if not isinstance(ret, Sequence):
            ret = (ret,)
        new_ret = []
        for i, r in enumerate(ret):
            j = ctx.input_mask[i]
            if j < 0:
                # -2 in input_mask represents it is stop_grad
                assert r is None or j == -2, f"{type(func)}'s {i}-th returned grad should be None, "\
                    "because the input value is not a jittor variable."
            else:
                new_ret.append(r)
        return new_ret

    def dfs(self, parents, k, callback, callback_leave=None):
        pass

    @classmethod
    def apply(cls, *args, **kw):
        func = cls()
        return func(*args, **kw)
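# Worked example of the mask bookkeeping in __call__/_grad (illustrative):
# for inputs (a, 3, b) where `a` requires grad, 3 is not a Var, and `b` is
# stop_grad, input_mask becomes [0, -1, -2] -- index 0 points into
# taped_inputs, -1 marks a non-Var argument, -2 marks a stop_grad Var.
# output_mask maps each forward output to its slot in taped_outputs the same
# way, so _grad can route incoming output gradients to backward() and check
# that backward() returns None exactly for the non-differentiable inputs.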
@@ -0,0 +1,39 @@
import jittor as jt
import jittor_utils
import glob
import os
from jittor import pyjt_compiler
import sys
from jittor_utils import lock


jtorch_path = os.path.dirname(__file__)
cache_path = os.path.join(jt.compiler.cache_path, "jtorch")
# os.makedirs(cache_path, exist_ok=True)
os.makedirs(os.path.join(cache_path, "gen"), exist_ok=True)

with lock.lock_scope():
    pyjt_gen_src = pyjt_compiler.compile(cache_path, jtorch_path)

ext_args = 'c[cu]' if jt.has_cuda else 'cc'
files = glob.glob(jtorch_path+"/src/**/*."+ext_args, recursive=True)
files += pyjt_gen_src
cc_flags = " -I\""+os.path.join(jtorch_path, "src")+"\" "
if os.environ.get("use_data_o", "1") == "1":
    files += glob.glob(jtorch_path+"/src/**/*.o", recursive=True)
    files = [f for f in files if "__data__" not in f]


with lock.lock_scope():
    jt.compiler.compile(
        jt.compiler.cc_path,
        jt.compiler.cc_flags + jt.compiler.opt_flags + cc_flags,
        files,
        "jtorch_core" + jt.compiler.extension_suffix,
        obj_dirname="jtorch_objs")


with jittor_utils.import_scope(jt.compiler.import_flags):
    import jtorch_core as core

jt.flags.th_mode = 1
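# How the extension glob above works (illustrative): the character class in
# "*.c[cu]" matches both C++ and CUDA sources, while "*.cc" matches only C++:
#
#     import fnmatch
#     assert fnmatch.fnmatch("op.cc", "*.c[cu]")
#     assert fnmatch.fnmatch("op.cu", "*.c[cu]")
#     assert not fnmatch.fnmatch("op.cu", "*.cc")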
@@ -0,0 +1,64 @@
import jittor as jt
import jtorch

def is_available():
    return jt.has_cuda

def device_count():
    return int(jt.has_cuda)

def set_device(device=None):
    pass

def get_rng_state(device=None):
    pass

def current_device():
    return jtorch.device("cuda")

def mem_get_info(i):
    # stub: returns a placeholder rather than real (free, total) memory info
    return ("75GB",)


class Generator:
    def __init__(self):
        pass

    def set_state(self, state):
        self.state = state

default_generators = [Generator()]
_lazy_call = lambda func: func()
device = None

LongTensor = jt.int64
FloatTensor = jt.float
HalfTensor = jt.float16
BoolTensor = jt.bool

manual_seed = jt.set_global_seed
manual_seed_all = jt.set_global_seed

def synchronize():
    jt.sync_all(True)

class Event:
    pass

class Stream:
    pass

from typing import Any

from .gradscaler import GradScaler

class autocast:
    def __init__(self, **kwargs):
        pass

    def __enter__(self):
        pass

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
        pass
@@ -0,0 +1,53 @@
import datetime
from enum import Enum
import jittor as jt


class DistributedDataParallel:
    def __new__(cls, model):
        return model

def is_initialized():
    return True

def get_rank(group=None):
    return 0

def get_world_size(group=None):
    return 1

def get_backend(group=None):
    return "nccl"

def new_group(ranks=None, timeout=datetime.timedelta(seconds=1800), backend=None, pg_options=None):
    return 1

def barrier():
    pass

def is_available():
    return True

def is_built():
    return True

class ReduceOp:
    SUM = 0

class GroupMember:
    WORLD = 0

class ProcessGroup:
    pass

class Join:
    pass

dist_backend = Enum("dist_backend", ("GLOO", "MPI", "NCCL"))
_backend = dist_backend.NCCL

def is_mpi_available():
    return jt.in_mpi

# NOTE: this function redefines (and shadows) the DistributedDataParallel class above.
def DistributedDataParallel(model, *args, **kw):
    return model
@@ -0,0 +1,15 @@
import jittor as jt

class RelaxedBernoulli:
    def __init__(self, temperature, probs=None, logits=None):
        self.temperature = temperature
        self.probs = probs
        self.logits = logits

    def rsample(self):
        noise = jt.rand_like(self.logits)
        eps = 1e-20
        noise = jt.clamp(noise, eps, 1.0 - eps)
        logit_noise = jt.log(noise) - jt.log(1 - noise)
        sample = (self.logits + logit_noise) / self.temperature
        return jt.sigmoid(sample)
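# A minimal usage sketch (illustrative): rsample() draws a differentiable
# "soft" Bernoulli sample via the logistic reparameterization
# sigmoid((logits + log u - log(1 - u)) / temperature) with u ~ Uniform(0, 1),
# so gradients flow back into `logits`:
#
#     dist = RelaxedBernoulli(temperature=0.5, logits=jt.array([0.0, 2.0, -2.0]))
#     s = dist.rsample()  # values in (0, 1); lower temperature -> closer to {0, 1}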
@@ -0,0 +1,5 @@
# TODO: Implement FFT and IFFT
fftn = None
fftshift = None
ifftn = None
ifftshift = None
@@ -0,0 +1,2 @@
class Proxy:
    pass
@@ -0,0 +1,519 @@
from collections import defaultdict, abc
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, cast
import inspect
import warnings

import jittor as jt
# import torch

def _refresh_per_optimizer_state():
    return {}


class GradScaler:
    _scale: Optional[jt.Var]
    _growth_tracker: Optional[jt.Var]
    _per_optimizer_states: Dict[int, Dict[str, Any]]
    """
    An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
    conveniently.

    * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
    * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
    * ``scaler.update()`` updates ``scaler``'s scale factor.

    Example::

        # Creates a GradScaler once at the beginning of training.
        scaler = GradScaler()

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()
                output = model(input)
                loss = loss_fn(output, target)

                # Scales loss. Calls backward() on scaled loss to create scaled gradients.
                scaler.scale(loss).backward()

                # scaler.step() first unscales gradients of the optimizer's params.
                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(optimizer)

                # Updates the scale for next iteration.
                scaler.update()

    See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
    (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty,
    and multiple losses/optimizers.

    ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
    a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
    the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
    without incurring inf or NaN gradient values.
    ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
    ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).

    * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
      themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.

    * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
      If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
      ``growth_factor``.

    The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
    value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
    iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).

    Args:
        init_scale (float, optional, default=2.**16): Initial scale factor.
        growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
            :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
        backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
            :meth:`update` if inf/NaN gradients occur in an iteration.
        growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
            that must occur for the scale to be multiplied by ``growth_factor``.
        enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
            invokes the underlying ``optimizer.step()``, and other methods become no-ops.
            Default: ``True``
    """
    def __init__(self,
                 init_scale=2.**16,
                 growth_factor=2.0,
                 backoff_factor=0.5,
                 growth_interval=2000,
                 enabled=True):
        self._enabled = enabled

        if self._enabled:
            assert growth_factor > 1.0, "The growth factor must be > 1.0."
            assert backoff_factor < 1.0, "The backoff factor must be < 1.0."

            self._init_scale = init_scale
            # self._scale will be lazily initialized during the first call to scale()
            self._scale = None
            self._growth_factor = growth_factor
            self._backoff_factor = backoff_factor
            self._growth_interval = growth_interval
            self._init_growth_tracker = 0
            # self._growth_tracker will be lazily initialized during the first call to scale()
            self._growth_tracker = None
            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    def _check_scale_growth_tracker(self, funcname) -> Tuple[jt.Var, jt.Var]:
        fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
        assert self._scale is not None, "Attempted {} but _scale is None.  ".format(funcname) + fix
        assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None.  ".format(funcname) + fix
        return (self._scale, self._growth_tracker)

    def _lazy_init_scale_growth_tracker(self):
        assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
        self._scale = self._init_scale
        self._growth_tracker = self._init_growth_tracker

    def scale(self, outputs):
        """
        Multiplies ('scales') a tensor or list of tensors by the scale factor.

        Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
        unmodified.

        Args:
            outputs (Tensor or iterable of Tensors): Outputs to scale.
        """
        if not self._enabled:
            return outputs

        # Short-circuit for the common case.
        if isinstance(outputs, jt.Var):
            assert jt.flags.use_cuda == 1
            if self._scale is None:
                self._lazy_init_scale_growth_tracker()
            assert self._scale is not None
            return outputs * self._scale

        def apply_scale(val):
            if isinstance(val, jt.Var):
                assert jt.flags.use_cuda == 1
                if self._scale is None:
                    self._lazy_init_scale_growth_tracker()
                assert self._scale is not None
                return val * self._scale
            elif isinstance(val, abc.Iterable):
                iterable = map(apply_scale, val)
                if isinstance(val, (list, tuple)):
                    return type(val)(iterable)
                else:
                    return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)

    def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
        with jt.no_grad():
            optimizer.pre_step()
            for group in optimizer.param_groups:
                for to_unscale in group["grads"]:
                    if to_unscale is None or isinstance(to_unscale, (int, float)):
                        continue
                    if (not allow_fp16) and str(to_unscale.dtype) == "float16":
                        raise ValueError("Attempting to unscale FP16 gradients.")

                    if not (to_unscale.isinf().any()):
                        if inv_scale != 1.0:
                            to_unscale.update(to_unscale * inv_scale)
                    else:
                        found_inf = 1.0

        return found_inf

    def unscale_(self, optimizer):
        """
        Divides ("unscales") the optimizer's gradient tensors by the scale factor.

        :meth:`unscale_` is optional, serving cases where you need to
        :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
        between the backward pass(es) and :meth:`step`.
        If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.

        Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::

            ...
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()

        Args:
            optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.

        .. note::
            :meth:`unscale_` does not incur a CPU-GPU sync.

        .. warning::
            :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
            and only after all gradients for that optimizer's assigned parameters have been accumulated.
            Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.

        .. warning::
            :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
        """
        if not self._enabled:
            return

        self._check_scale_growth_tracker("unscale_")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        # optimizers with built-in inf detection unscale inside their own step()
        if hasattr(optimizer, "get_find_inf"):
            return
        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
        assert self._scale is not None
        inv_scale = 1.0 / self._scale
        found_inf = 0.0
        optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)


    def step(self, optimizer, *args, **kwargs):
        """
        :meth:`step` carries out the following two operations:

        1.  Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
            earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
        2.  If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
            gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.

        ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.

        Returns the return value of ``optimizer.step(*args, **kwargs)``.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
            args: Any arguments.
            kwargs: Any keyword arguments.

        .. warning::
            Closure use is not currently supported.
        """
        if (not self._enabled):
            return optimizer.step(*args, **kwargs)

        if "closure" in kwargs:
            raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")

        self._check_scale_growth_tracker("step")

        optimizer_state = self._per_optimizer_states[id(optimizer)]
        retval = None

        if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling):
            # NOTE: this branch mirrors the torch implementation and assumes an
            # OptState enum and a "stage" entry in optimizer_state (see the other
            # GradScaler variant in this commit); plain jittor optimizers never hit it.
            # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
            # The contract with custom optimizers is that their step() should accept an additional,
            # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
            # it can query its own state, invoke unscale_ on itself, etc
            # The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument
            # to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale`
            # and `found_inf` to the passed optimizer so that the optimizer can utilize those
            # to skip the parameter updates or unscale gradients before updating parameters in
            # the fused kernel, e.g. `FusedAdamMathFunctor`.
            # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`,
            # while the method is expected to be called by users side, i.e. their optimizers.
            kwargs_ = kwargs
            has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters
            if has_grad_scaler_kwarg:
                warnings.warn(
                    "GradScaler is going to stop passing itself as a keyword argument to the passed "
                    "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and "
                    "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.",
                    FutureWarning)
                kwargs_.update({"grad_scaler": self})
            else:
                if optimizer_state["stage"] is OptState.READY:
                    self._check_inf_per_device(optimizer)
                scaler = self._get_scale_async()
                found_inf = cast(
                    jt.Var,
                    sum([
                        t for t in optimizer_state["found_inf_per_device"].values()
                    ])
                )
                optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler
                optimizer.found_inf = found_inf
            retval = optimizer.step(*args, **kwargs_)
            optimizer_state["stage"] = OptState.STEPPED
            if not has_grad_scaler_kwarg:
                del optimizer.grad_scale
                del optimizer.found_inf
            return retval

        if hasattr(optimizer, "get_find_inf"):
            optimizer.set_grad_scale(self._scale)
            optimizer.step()
            optimizer_state["found_inf_per_device"] = optimizer.get_find_inf()
            return

        retval = None
        if not optimizer_state["found_inf_per_device"]:
            retval = optimizer.step(*args, **kwargs)
        else:
            optimizer.post_step()

        return retval


    def update(self, new_scale=None):
        """
        Updates the scale factor.

        If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
        the scale is multiplied by ``growth_factor`` to increase it.

        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
        used directly, it's used to fill GradScaler's internal scale tensor. So if
        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
        affect the scale GradScaler uses internally.)

        Args:
            new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.

        .. warning::
            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
            been invoked for all optimizers used this iteration.
        """
        if not self._enabled:
            return

        _scale, _growth_tracker = self._check_scale_growth_tracker("update")

        if new_scale is not None:
            # Accept a new user-defined scale.
            if isinstance(new_scale, float):
                # torch filled an internal tensor in place; _scale is a plain float here
                self._scale = float(new_scale)
            else:
                reason = "new_scale should be a float or a 1-element jt.Var with requires_grad=False."
                # torch asserted isinstance(new_scale, torch.cuda.FloatTensor); torch is not
                # imported in this port, so accept a 1-element jt.Var instead.
                assert isinstance(new_scale, jt.Var), reason
                assert new_scale.numel() == 1, reason
                assert new_scale.requires_grad is False, reason
                self._scale = new_scale
        else:
            # Consume shared inf/nan data collected from optimizers to update the scale.
            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
            found_infs = [state["found_inf_per_device"]
                          for state in self._per_optimizer_states.values()
                          ]

            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]
            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf_combined += found_infs[i]


            current_scale = _scale
            if found_inf_combined:
                current_scale *= self._backoff_factor
                _growth_tracker = 0
            else:
                successful = _growth_tracker + 1
                if successful == self._growth_interval:
                    new_scale = current_scale * self._growth_factor
                    if new_scale < 1e9:
                        current_scale = new_scale
                    _growth_tracker = 0
                else:
                    _growth_tracker = successful

            self._scale, self._growth_tracker = current_scale, _growth_tracker

        # To prepare for next iteration, clear the data collected from optimizers this iteration.
        self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    def _get_scale_async(self):
        return self._scale

    def get_scale(self):
        """
        Returns a Python float containing the current scale, or 1.0 if scaling is disabled.

        .. warning::
            :meth:`get_scale` incurs a CPU-GPU sync.
        """
        if self._enabled:
            return self._init_scale if self._scale is None else self._get_scale_async()
        else:
            return 1.0

    def get_growth_factor(self):
        r"""
        Returns a Python float containing the scale growth factor.
        """
        return self._growth_factor

    def set_growth_factor(self, new_factor):
        r"""
        Args:
            new_factor (float): Value to use as the new scale growth factor.
        """
        self._growth_factor = new_factor

    def get_backoff_factor(self):
        r"""
        Returns a Python float containing the scale backoff factor.
        """
        return self._backoff_factor

    def set_backoff_factor(self, new_factor):
        r"""
        Args:
            new_factor (float): Value to use as the new scale backoff factor.
        """
        self._backoff_factor = new_factor

    def get_growth_interval(self):
        r"""
        Returns a Python int containing the growth interval.
        """
        return self._growth_interval

    def set_growth_interval(self, new_interval):
        r"""
        Args:
            new_interval (int): Value to use as the new growth interval.
        """
        self._growth_interval = new_interval

    def _get_growth_tracker(self):
        if self._enabled:
            # _growth_tracker is a plain int in this port (torch kept a tensor and called .item())
            return self._init_growth_tracker if self._growth_tracker is None else int(self._growth_tracker)
        else:
            return 0

    def is_enabled(self):
        r"""
        Returns a bool indicating whether this instance is enabled.
        """
        return self._enabled

    def state_dict(self):
        r"""
        Returns the state of the scaler as a :class:`dict`. It contains five entries:

        * ``"scale"`` - a Python float containing the current scale
        * ``"growth_factor"`` - a Python float containing the current growth factor
        * ``"backoff_factor"`` - a Python float containing the current backoff factor
        * ``"growth_interval"`` - a Python int containing the current growth interval
        * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.

        If this instance is not enabled, returns an empty dict.

        .. note::
            If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
            should be called after :meth:`update`.
        """
        return {"scale": self.get_scale(),
                "growth_factor": self._growth_factor,
                "backoff_factor": self._backoff_factor,
                "growth_interval": self._growth_interval,
                "_growth_tracker": self._get_growth_tracker()} if self._enabled else {}

    def load_state_dict(self, state_dict):
        r"""
        Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.

        Args:
            state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
        """
        if not self._enabled:
            return

        if len(state_dict) == 0:
            raise RuntimeError("The source state dict is empty, possibly because it was saved "
                               "from a disabled instance of GradScaler.")

        self._init_scale = state_dict["scale"]
        if self._scale is not None:
            # torch used self._scale.fill_(...); _scale is a plain float here
            self._scale = state_dict["scale"]
        self._growth_factor = state_dict["growth_factor"]
        self._backoff_factor = state_dict["backoff_factor"]
        self._growth_interval = state_dict["growth_interval"]
        self._init_growth_tracker = state_dict["_growth_tracker"]
        if self._growth_tracker is not None:
            self._growth_tracker = state_dict["_growth_tracker"]

    def __getstate__(self):
        state = self.__dict__.copy()
        if self._enabled:
            assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\
                "of an iteration, or at the end after scaler.update()."
            # Pickling _scale and _growth_tracker Tensors directly triggers
            # "warnings.warn("pickle support for Storage will be removed in 1.5..."
            # so instead, we set the unpickled instance up to reinitialize them lazily.
            state['_init_scale'] = self.get_scale()
            state['_init_growth_tracker'] = self._get_growth_tracker()
            state['_scale'] = None
            state['_growth_tracker'] = None
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)

    def _check_inf_per_device(self, optimizer):
        _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")

        dummy_inv_scale = 1.0
        found_inf = 0.0

        self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
            self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)

        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

    def _found_inf_per_device(self, optimizer):
        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
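# Worked example of the update rule with the defaults (init_scale=2**16,
# backoff_factor=0.5, growth_factor=2.0, growth_interval=2000): one step with
# an inf/NaN gradient multiplies the scale by 0.5 (65536 -> 32768) and resets
# the growth tracker; 2000 consecutive clean steps then multiply it by 2.0
# (32768 -> 65536), provided the grown scale stays below the 1e9 cap in update().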
@ -0,0 +1,556 @@
|
|||
from collections import defaultdict, abc
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
import inspect
|
||||
import warnings
|
||||
|
||||
import jittor as jt
|
||||
# import torch
|
||||
|
||||
|
||||
__all__ = ["OptState", "GradScaler"]
|
||||
|
||||
|
||||
# Defines default_factory for GradScaler's _per_optimizer_states defaultdict,
|
||||
# as well as associated "enum" values. Prefers defining these at top level because
|
||||
# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory.
|
||||
# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler
|
||||
# causes a circular reference, which we'd rather avoid.
|
||||
class OptState(Enum):
|
||||
READY = 0
|
||||
UNSCALED = 1
|
||||
STEPPED = 2
|
||||
|
||||
|
||||
def _refresh_per_optimizer_state():
|
||||
return {"stage": OptState.READY, "found_inf_per_device": {}}
|
||||
|
||||
|
||||
class GradScaler:
|
||||
_scale: Optional[jt.Var]
|
||||
_grows_tracker: Optional[jt.Var]
|
||||
_per_optimizer_states: Dict[int, Dict[str, Any]]
|
||||
"""
|
||||
An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
|
||||
conveniently.
|
||||
|
||||
* ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
|
||||
* ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
|
||||
* ``scaler.update()`` updates ``scaler``'s scale factor.
|
||||
|
||||
Example::
|
||||
|
||||
# Creates a GradScaler once at the beginning of training.
|
||||
scaler = GradScaler()
|
||||
|
||||
for epoch in epochs:
|
||||
for input, target in data:
|
||||
optimizer.zero_grad()
|
||||
output = model(input)
|
||||
loss = loss_fn(output, target)
|
||||
|
||||
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# scaler.step() first unscales gradients of the optimizer's params.
|
||||
# If gradients don't contain infs/NaNs, optimizer.step() is then called,
|
||||
# otherwise, optimizer.step() is skipped.
|
||||
scaler.step(optimizer)
|
||||
|
||||
# Updates the scale for next iteration.
|
||||
scaler.update()
|
||||
|
||||
See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
|
||||
(along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty,
|
||||
and multiple losses/optimizers.
|
||||
|
||||
``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
|
||||
a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
|
||||
the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
|
||||
without incurring inf or NaN gradient values.
|
||||
``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
|
||||
``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
|
||||
|
||||
* If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
|
||||
themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
|
||||
|
||||
* If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
|
||||
If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
|
||||
``growth_factor``.
|
||||
|
||||
The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
|
||||
value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
|
||||
iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).
|
||||
|
||||
Args:
|
||||
init_scale (float, optional, default=2.**16): Initial scale factor.
|
||||
growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
|
||||
:meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
|
||||
backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
|
||||
:meth:`update` if inf/NaN gradients occur in an iteration.
|
||||
growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
|
||||
that must occur for the scale to be multiplied by ``growth_factor``.
|
||||
enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
|
||||
invokes the underlying ``optimizer.step()``, and other methods become no-ops.
|
||||
Default: ``True``
|
||||
"""
|
||||
def __init__(self,
|
||||
init_scale=2.**16,
|
||||
growth_factor=2.0,
|
||||
backoff_factor=0.5,
|
||||
growth_interval=2000,
|
||||
enabled=True):
|
||||
self._enabled = enabled
|
||||
|
||||
if self._enabled:
|
||||
assert growth_factor > 1.0, "The growth factor must be > 1.0."
|
||||
assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
|
||||
|
||||
self._init_scale = init_scale
|
||||
# self._scale will be lazily initialized during the first call to scale()
|
||||
self._scale = None
|
||||
self._growth_factor = growth_factor
|
||||
self._backoff_factor = backoff_factor
|
||||
self._growth_interval = growth_interval
|
||||
self._init_growth_tracker = 0
|
||||
# self._growth_tracker will be lazily initialized during the first call to scale()
|
||||
self._growth_tracker = None
|
||||
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
|
||||
|
||||
def _check_scale_growth_tracker(self, funcname) -> Tuple[jt.Var, jt.Var]:
|
||||
fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
|
||||
assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
|
||||
assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
|
||||
return (self._scale, self._growth_tracker)
|
||||
|
||||
def _lazy_init_scale_growth_tracker(self):
|
||||
assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
|
||||
self._scale = self._init_scale
|
||||
self._growth_tracker = self._init_growth_tracker
|
||||
|
||||
def scale(self, outputs):
|
||||
"""
|
||||
Multiplies ('scales') a tensor or list of tensors by the scale factor.
|
||||
|
||||
Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
|
||||
unmodified.
|
||||
|
||||
Args:
|
||||
outputs (Tensor or iterable of Tensors): Outputs to scale.
|
||||
"""
|
||||
print("scale")
|
||||
if not self._enabled:
|
||||
return outputs
|
||||
|
||||
|
||||
# Short-circuit for the common case.
|
||||
if isinstance(outputs, jt.Var):
|
||||
assert jt.flags.use_cuda == 1
|
||||
if self._scale is None:
|
||||
self._lazy_init_scale_growth_tracker()
|
||||
assert self._scale is not None
|
||||
return outputs * self._scale
|
||||
|
||||
def apply_scale(val):
|
||||
if isinstance(val, jt.Var):
|
||||
assert jt.flags.use_cuda == 1
|
||||
if self._scale is None:
|
||||
self._lazy_init_scale_growth_tracker()
|
||||
assert self._scale is not None
|
||||
return val * self._scale
|
||||
elif isinstance(val, abc.Iterable):
|
||||
iterable = map(apply_scale, val)
|
||||
if isinstance(val, (list, tuple)):
|
||||
return type(val)(iterable)
|
||||
else:
|
||||
return iterable
|
||||
else:
|
||||
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
|
||||
|
||||
return apply_scale(outputs)
|
||||
|
||||
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
|
||||
|
||||
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
|
||||
# There could be hundreds of grads, so we'd like to iterate through them just once.
|
||||
# However, we don't know their devices or dtypes in advance.
|
||||
|
||||
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
|
||||
# Google says mypy struggles with defaultdicts type annotations.
|
||||
with jt.no_grad():
|
||||
optimizer.pre_step()
|
||||
for group in optimizer.param_groups:
|
||||
for to_unscale in group["grads"]:
|
||||
if to_unscale is None or isinstance(to_unscale,(int,float)):
|
||||
continue
|
||||
if (not allow_fp16) and str(to_unscale.dtype) == "float16":
|
||||
raise ValueError("Attempting to unscale FP16 gradients.")
|
||||
|
||||
if not (to_unscale.isinf().any()):
|
||||
if inv_scale != 1.0:
|
||||
to_unscale.update(to_unscale*inv_scale)
|
||||
else:
|
||||
found_inf = 1.0
|
||||
|
||||
return found_inf
|
||||
|
||||
def unscale_(self, optimizer):
|
||||
"""
|
||||
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
|
||||
|
||||
:meth:`unscale_` is optional, serving cases where you need to
|
||||
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
|
||||
between the backward pass(es) and :meth:`step`.
|
||||
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
|
||||
|
||||
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
|
||||
|
||||
...
|
||||
scaler.scale(loss).backward()
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
||||
Args:
|
||||
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
|
||||
|
||||
.. note::
|
||||
:meth:`unscale_` does not incur a CPU-GPU sync.
|
||||
|
||||
.. warning::
|
||||
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
|
||||
and only after all gradients for that optimizer's assigned parameters have been accumulated.
|
||||
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
|
||||
|
||||
.. warning::
|
||||
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
|
||||
"""
|
||||
if not self._enabled:
|
||||
return
|
||||
|
||||
self._check_scale_growth_tracker("unscale_")
|
||||
|
||||
optimizer_state = self._per_optimizer_states[id(optimizer)]
|
||||
|
||||
if optimizer_state["stage"] is OptState.UNSCALED:
|
||||
raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")
|
||||
elif optimizer_state["stage"] is OptState.STEPPED:
|
||||
raise RuntimeError("unscale_() is being called after step().")
|
||||
|
||||
|
||||
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
|
||||
assert self._scale is not None
|
||||
inv_scale = 1.0 / self._scale
|
||||
found_inf = 0.0
|
||||
optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
|
||||
optimizer_state["stage"] = OptState.UNSCALED
|
||||
|
||||
def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
|
||||
retval = None
|
||||
if not optimizer_state["found_inf_per_device"]:
|
||||
retval = optimizer.step(*args, **kwargs)
|
||||
else:
|
||||
optimizer.post_step()
|
||||
|
||||
return retval
|
||||
|
||||
    def step(self, optimizer, *args, **kwargs):
        """
        :meth:`step` carries out the following two operations:

        1.  Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
            earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
        2.  If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
            gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.

        ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.

        Returns the return value of ``optimizer.step(*args, **kwargs)``.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
            args: Any arguments.
            kwargs: Any keyword arguments.

        .. warning::
            Closure use is not currently supported.
        """
        if not self._enabled:
            return optimizer.step(*args, **kwargs)

        if "closure" in kwargs:
            raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")

        self._check_scale_growth_tracker("step")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError("step() has already been called since the last update().")

        retval = None

        if hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling:
            # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
            # The contract with custom optimizers is that their step() should accept an additional,
            # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
            # it can query its own state, invoke unscale_ on itself, etc.
            # The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument
            # to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale`
            # and `found_inf` to the passed optimizer so that the optimizer can utilize those
            # to skip the parameter updates or unscale gradients before updating parameters in
            # the fused kernel, e.g. `FusedAdamMathFunctor`.
            # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`,
            # while the method is expected to be called by users side, i.e. their optimizers.
            kwargs_ = kwargs
            has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters
            if has_grad_scaler_kwarg:
                warnings.warn(
                    "GradScaler is going to stop passing itself as a keyword argument to the passed "
                    "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and "
                    "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.",
                    FutureWarning)
                kwargs_.update({"grad_scaler": self})
            else:
                if optimizer_state["stage"] is OptState.READY:
                    self._check_inf_per_device(optimizer)
                scaler = self._get_scale_async()
                found_inf = cast(
                    jt.Var,
                    sum([
                        t for t in optimizer_state["found_inf_per_device"].values()
                    ])
                )
                optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler
                optimizer.found_inf = found_inf
            retval = optimizer.step(*args, **kwargs_)
            optimizer_state["stage"] = OptState.STEPPED
            if not has_grad_scaler_kwarg:
                del optimizer.grad_scale
                del optimizer.found_inf
            return retval

        if optimizer_state["stage"] is OptState.READY:
            self.unscale_(optimizer)

        assert "found_inf_per_device" in optimizer_state, "No inf checks were recorded for this optimizer."

        retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)

        optimizer_state["stage"] = OptState.STEPPED

        return retval

    def update(self, new_scale=None):
        """
        Updates the scale factor.

        If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
        the scale is multiplied by ``growth_factor`` to increase it.

        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
        used directly, it's used to fill GradScaler's internal scale tensor. So if
        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
        affect the scale GradScaler uses internally.)

        Args:
            new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.

        .. warning::
            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
            been invoked for all optimizers used this iteration.
        """
        if not self._enabled:
            return

        _scale, _growth_tracker = self._check_scale_growth_tracker("update")

        if new_scale is not None:
            # Accept a new user-defined scale.
            if isinstance(new_scale, float):
                self._scale.fill_(new_scale)  # type: ignore[union-attr]
            else:
                reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
                assert isinstance(new_scale, torch.cuda.FloatTensor), reason  # type: ignore[attr-defined]
                assert new_scale.numel() == 1, reason
                assert new_scale.requires_grad is False, reason
                self._scale.copy_(new_scale)  # type: ignore[union-attr]
        else:
            # Consume shared inf/nan data collected from optimizers to update the scale.
            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
            found_infs = [state["found_inf_per_device"]
                          for state in self._per_optimizer_states.values()]

            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]
            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf_combined += found_infs[i]

            current_scale = _scale
            if found_inf_combined:
                current_scale *= self._backoff_factor
                _growth_tracker = 0
            else:
                successful = _growth_tracker + 1
                if successful == self._growth_interval:
                    new_scale = current_scale * self._growth_factor
                    if new_scale < 1e9:
                        current_scale = new_scale
                    _growth_tracker = 0
                else:
                    _growth_tracker = successful

            self._scale, self._growth_tracker = current_scale, _growth_tracker

        # To prepare for next iteration, clear the data collected from optimizers this iteration.
        self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

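    # A minimal end-to-end sketch of how the methods above compose into the
    # usual amp training loop (`model`, `optimizer`, `loss_fn`, and `data` are
    # hypothetical placeholders, not names from this repo):
    #
    #     scaler = GradScaler()
    #     for input, target in data:
    #         optimizer.zero_grad()
    #         loss = loss_fn(model(input), target)
    #         scaler.scale(loss).backward()
    #         scaler.step(optimizer)   # skipped internally if infs/NaNs were found
    #         scaler.update()          # backoff or growth is decided here
    #
    # For example, with growth_factor=2.0, backoff_factor=0.5 and
    # growth_interval=2000, a scale of 65536.0 drops to 32768.0 after one
    # overflowing step, and doubles again only after 2000 consecutive clean steps.
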
    def _get_scale_async(self):
        return self._scale

    def get_scale(self):
        """
        Returns a Python float containing the current scale, or 1.0 if scaling is disabled.

        .. warning::
            :meth:`get_scale` incurs a CPU-GPU sync.
        """
        if self._enabled:
            return self._init_scale if self._scale is None else self._get_scale_async()
        else:
            return 1.0

    def get_growth_factor(self):
        r"""
        Returns a Python float containing the scale growth factor.
        """
        return self._growth_factor

    def set_growth_factor(self, new_factor):
        r"""
        Args:
            new_factor (float): Value to use as the new scale growth factor.
        """
        self._growth_factor = new_factor

    def get_backoff_factor(self):
        r"""
        Returns a Python float containing the scale backoff factor.
        """
        return self._backoff_factor

    def set_backoff_factor(self, new_factor):
        r"""
        Args:
            new_factor (float): Value to use as the new scale backoff factor.
        """
        self._backoff_factor = new_factor

    def get_growth_interval(self):
        r"""
        Returns a Python int containing the growth interval.
        """
        return self._growth_interval

    def set_growth_interval(self, new_interval):
        r"""
        Args:
            new_interval (int): Value to use as the new growth interval.
        """
        self._growth_interval = new_interval

    def _get_growth_tracker(self):
        if self._enabled:
            return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item()
        else:
            return 0

    def is_enabled(self):
        r"""
        Returns a bool indicating whether this instance is enabled.
        """
        return self._enabled

    def state_dict(self):
        r"""
        Returns the state of the scaler as a :class:`dict`. It contains five entries:

        * ``"scale"`` - a Python float containing the current scale
        * ``"growth_factor"`` - a Python float containing the current growth factor
        * ``"backoff_factor"`` - a Python float containing the current backoff factor
        * ``"growth_interval"`` - a Python int containing the current growth interval
        * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.

        If this instance is not enabled, returns an empty dict.

        .. note::
           If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
           should be called after :meth:`update`.
        """
        return {"scale": self.get_scale(),
                "growth_factor": self._growth_factor,
                "backoff_factor": self._backoff_factor,
                "growth_interval": self._growth_interval,
                "_growth_tracker": self._get_growth_tracker()} if self._enabled else {}

    def load_state_dict(self, state_dict):
        r"""
        Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.

        Args:
            state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
        """
        if not self._enabled:
            return

        if len(state_dict) == 0:
            raise RuntimeError("The source state dict is empty, possibly because it was saved "
                               "from a disabled instance of GradScaler.")

        self._init_scale = state_dict["scale"]
        if self._scale is not None:
            self._scale.fill_(state_dict["scale"])
        self._growth_factor = state_dict["growth_factor"]
        self._backoff_factor = state_dict["backoff_factor"]
        self._growth_interval = state_dict["growth_interval"]
        self._init_growth_tracker = state_dict["_growth_tracker"]
        if self._growth_tracker is not None:
            self._growth_tracker.fill_(state_dict["_growth_tracker"])

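    # Checkpointing sketch: only the scaler entries are specific to this class;
    # the surrounding `model`/`optimizer` objects are assumed to exist.
    #
    #     checkpoint = {"model": model.state_dict(),
    #                   "optimizer": optimizer.state_dict(),
    #                   "scaler": scaler.state_dict()}  # call after scaler.update()
    #     ...
    #     scaler.load_state_dict(checkpoint["scaler"])
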
    def __getstate__(self):
        state = self.__dict__.copy()
        if self._enabled:
            assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\
                                                         "of an iteration, or at the end after scaler.update()."
            # Pickling _scale and _growth_tracker Tensors directly triggers
            # "warnings.warn("pickle support for Storage will be removed in 1.5..."
            # so instead, we set the unpickled instance up to reinitialize them lazily.
            state['_init_scale'] = self.get_scale()
            state['_init_growth_tracker'] = self._get_growth_tracker()
            state['_scale'] = None
            state['_growth_tracker'] = None
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)

    def _check_inf_per_device(self, optimizer):
        _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")

        dummy_inv_scale = 1.0
        found_inf = 0.0

        self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
            self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)

        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

    def _found_inf_per_device(self, optimizer):
        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

@ -0,0 +1,12 @@
import math

def _jit_set_profiling_mode(x): pass
def _jit_set_profiling_executor(x): pass
def _jit_override_can_fuse_on_cpu(x): pass
def _jit_override_can_fuse_on_gpu(x): pass

def script(func):
    return func

inf = math.inf
nan = math.nan

@ -0,0 +1,281 @@
import jtorch
from typing import List, Optional, Tuple, Iterable, Iterator, Mapping, Any, overload, TypeVar, Dict
from typing_extensions import Self
import jittor as jt
from jtorch import make_module, Tensor, ModuleMisc, wrapper
#from . import init
from jittor import Function
import operator
import warnings

# Re-export jittor's functional interface and Module classes so they behave
# like their torch.nn counterparts.
for k,v in jt.nn.__dict__.items():
    if callable(v):
        globals()[k] = wrapper(v)

for k,v in jt.nn.__dict__.items():
    if isinstance(v, type) and issubclass(v, jt.Module):
        globals()[k] = make_module(v)

from collections import OrderedDict
from collections import abc as container_abcs

class Module(ModuleMisc, jt.Module):

    def __call__(self, *args, **kw):
        return self.execute(*args, **kw)

    def execute(self, *args, **kw):
        return self.forward(*args, **kw)

    def get_submodule(self, target: str):
        if target == "":
            return self

        atoms: List[str] = target.split(".")
        mod: jt.nn.Module = self

        for item in atoms:
            if not hasattr(mod, item):
                raise AttributeError(mod._get_name() + " has no "
                                     "attribute `" + item + "`")

            mod = getattr(mod, item)

            if not isinstance(mod, jt.nn.Module):
                raise AttributeError("`" + item + "` is not "
                                     "an nn.Module")
        return mod

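# Usage sketch for get_submodule (module names here are illustrative):
#
#     class Block(Module):
#         def __init__(self):
#             self.linear = Linear(4, 4)
#         def forward(self, x):
#             return self.linear(x)
#
#     class Net(Module):
#         def __init__(self):
#             self.block = Block()
#         def forward(self, x):
#             return self.block(x)
#
#     Net().get_submodule("block.linear")   # -> the Linear module
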
def Parameter(x:Tensor, requires_grad:bool=True) -> Tensor:
    x = x.clone()
    x.requires_grad = requires_grad
    x.retains_grad = requires_grad
    return x

def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False):
    # Only the basic lookup is forwarded; padding_idx, max_norm, norm_type,
    # scale_grad_by_freq and sparse are accepted for API compatibility but ignored.
    return jt.nn.embedding(input, weight)

def dropout(x, p=0.5, training=False):
    return jt.nn.dropout(x, p, training)


class Flatten(Module):
    ''' Flattens the contiguous range of dimensions in a Var.

    :param start_dim: the first dimension to be flattened. Default: 1.
    :type start_dim: int
    :param end_dim: the last dimension to be flattened. Default: -1.
    :type end_dim: int
    '''
    def __init__(self, start_dim=1, end_dim=-1):
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, x) -> jt.Var:
        return x.flatten(self.start_dim, self.end_dim)

class _IncompatibleKeys:
    def __init__(self, missing_keys, unexpected_keys):
        self.missing_keys = missing_keys
        self.unexpected_keys = unexpected_keys

_BatchNorm = None

#from . import utils
normalize = wrapper(jt.normalize)

T = TypeVar('T', bound=Module)

class ModuleDict(Module):
    _modules: Dict[str, Module]  # type: ignore[assignment]

    def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
        super().__init__()
        if modules is not None:
            self.update(modules)

    def __getitem__(self, key: str) -> Module:
        return self._modules[key]

    def __setitem__(self, key: str, module: Module) -> None:
        self.add_module(key, module)

    def __delitem__(self, key: str) -> None:
        del self._modules[key]

    def __len__(self) -> int:
        return len(self._modules)

    def __iter__(self) -> Iterator[str]:
        return iter(self._modules)

    def __contains__(self, key: str) -> bool:
        return key in self._modules

    def clear(self) -> None:
        """Remove all items from the ModuleDict."""
        self._modules.clear()

    def pop(self, key: str) -> Module:
        r"""Remove key from the ModuleDict and return its module.

        Args:
            key (str): key to pop from the ModuleDict
        """
        v = self[key]
        del self[key]
        return v

    def keys(self) -> Iterable[str]:
        r"""Return an iterable of the ModuleDict keys."""
        return self._modules.keys()

    def items(self) -> Iterable[Tuple[str, Module]]:
        r"""Return an iterable of the ModuleDict key/value pairs."""
        return self._modules.items()

    def values(self) -> Iterable[Module]:
        r"""Return an iterable of the ModuleDict values."""
        return self._modules.values()

    def update(self, modules: Mapping[str, Module]) -> None:
        r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys.

        .. note::
            If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.

        Args:
            modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
                or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError("ModuleDict.update should be called with an "
                            "iterable of key/value pairs, but got " +
                            type(modules).__name__)

        if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
            for key, module in modules.items():
                self[key] = module
        else:
            # modules here can be a list with two items
            for j, m in enumerate(modules):
                if not isinstance(m, container_abcs.Iterable):
                    raise TypeError("ModuleDict update sequence element "
                                    "#" + str(j) + " should be Iterable; is " +
                                    type(m).__name__)
                if not len(m) == 2:
                    raise ValueError("ModuleDict update sequence element "
                                     "#" + str(j) + " has length " + str(len(m)) +
                                     "; 2 is required")
                # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
                # that's too cumbersome to type correctly with overloads, so we add an ignore here
                self[m[0]] = m[1]  # type: ignore[assignment]

    # remove forward altogether to fall back on Module's _forward_unimplemented

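# A short ModuleDict usage sketch (the layer choices are illustrative and
# assume jittor provides Linear/ReLU/Dropout via the wrapping loops above):
#
#     layers = ModuleDict({"proj": Linear(8, 8), "act": ReLU()})
#     layers["drop"] = Dropout(0.1)    # routed through add_module
#     y = layers["act"](layers["proj"](x))
#     "proj" in layers                 # True; iteration yields keys
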
class ParameterList(Module):

    def __init__(self, values: Optional[Iterable[Any]] = None) -> None:
        super().__init__()
        self._size = 0
        if values is not None:
            self += values

    def _get_abs_string_index(self, idx):
        """Get the absolute index for the list of modules."""
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError(f'index {idx} is out of range')
        if idx < 0:
            idx += len(self)
        return str(idx)

    @overload
    def __getitem__(self, idx: int) -> Any:
        ...

    @overload
    def __getitem__(self: T, idx: slice) -> T:
        ...

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            out = self.__class__()
            for i in range(start, stop, step):
                out.append(self[i])
            return out
        else:
            idx = self._get_abs_string_index(idx)
            return getattr(self, str(idx))

    def __setitem__(self, idx: int, param: Any) -> None:
        # Note that all other functions that add an entry to the list part of
        # the ParameterList end up here. So this is the only place where we need
        # to wrap things into Parameter if needed.
        # Objects added via setattr() are not in the list part and thus won't
        # call into this function.
        idx = self._get_abs_string_index(idx)
        # `Parameter` is a plain function in jtorch (it returns a Var), so the
        # upstream `isinstance(param, Parameter)` check cannot work here; wrap
        # any Var that does not already require gradients instead.
        if isinstance(param, jt.Var) and not param.requires_grad:
            param = Parameter(param)
        return setattr(self, str(idx), param)

    def __len__(self) -> int:
        return self._size

    def __iter__(self) -> Iterator[Any]:
        return iter(self[i] for i in range(len(self)))

    def __iadd__(self, parameters: Iterable[Any]) -> Self:
        return self.extend(parameters)

    def __dir__(self):
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def append(self, value: Any) -> 'ParameterList':
        """Append a given value at the end of the list.

        Args:
            value (Any): value to append
        """
        new_idx = len(self)
        self._size += 1
        self[new_idx] = value
        return self

    def extend(self, values: Iterable[Any]) -> Self:
        """Append values from a Python iterable to the end of the list.

        Args:
            values (iterable): iterable of values to append
        """
        # Tensor is an iterable but we never want to unpack it here
        if not isinstance(values, container_abcs.Iterable) or isinstance(values, jt.Var):
            raise TypeError("ParameterList.extend should be called with an "
                            "iterable, but got " + type(values).__name__)
        for value in values:
            self.append(value)
        return self

    def extra_repr(self) -> str:
        child_lines = []
        for k, p in enumerate(self):
            if isinstance(p, jt.Var):
                size_str = 'x'.join(str(size) for size in p.size())
                # `Parameter` is a function here, so Parameters are told apart
                # from plain Tensors by requires_grad rather than isinstance().
                parastr = '{} containing: [{} of size {} on {}]'.format(
                    "Parameter" if p.requires_grad else "Tensor",
                    p.dtype, size_str, "cuda" if jt.flags.use_cuda else "cpu")
                child_lines.append('  (' + str(k) + '): ' + parastr)
            else:
                child_lines.append('  (' + str(k) + '): Object of type: ' + type(p).__name__)

        tmpstr = '\n'.join(child_lines)
        return tmpstr

    def __call__(self, *args, **kwargs):
        raise RuntimeError('ParameterList should not be called.')

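# ParameterList usage sketch (shapes are arbitrary examples):
#
#     params = ParameterList([jt.randn(4, 4) for _ in range(3)])
#     params.append(jt.randn(2))       # wrapped into a Parameter on insertion
#     weights = [p for p in params]    # iterates via integer indexing
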
@ -0,0 +1,16 @@
import jittor as jt

# Re-export jittor's init functions, then add torch-style aliases for the
# gaussian initializers (gauss/gauss_ etc. are injected by the loop below).
for k,v in jt.nn.init.__dict__.items():
    if callable(v):
        globals()[k] = v

normal = gauss
normal_ = gauss_
xavier_normal = xavier_gauss
xavier_normal_ = xavier_gauss_
zeros_ = zero_

jt.Var.normal_ = normal_

@ -0,0 +1 @@
from . import rnn

@ -0,0 +1,20 @@
import jittor as jt

PackedSequence = None

def pad_sequence(sequences, batch_first=False, padding_value=0.0):
    # Pad a list of variable-length Vars to the length of the longest one.
    max_f = max([len(s) for s in sequences])
    # max_f = 512
    b = len(sequences)
    if batch_first:
        ret = sequences[0].new_full([b, max_f] + list(sequences[0].shape[1:]), padding_value)
        for i, s in enumerate(sequences):
            ret[i, :len(s)] = s
    else:
        ret = sequences[0].new_full([max_f, b] + list(sequences[0].shape[1:]), padding_value)
        for i, s in enumerate(sequences):
            ret[:len(s), i] = s
    # print(ret.shape)
    # ret = ret[:,:406]
    return ret

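# Example: two sequences of lengths 3 and 1 pad to a (2, 3) batch when
# batch_first=True (values arbitrary; relies on the new_full patch that the
# code above already assumes jtorch provides on Vars):
#
#     a = jt.array([1.0, 2.0, 3.0])
#     b = jt.array([4.0])
#     pad_sequence([a, b], batch_first=True)
#     # -> [[1. 2. 3.]
#     #     [4. 0. 0.]]
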
File diff suppressed because it is too large

@ -0,0 +1,102 @@
#include "pyjt/py_obj_holder.h"
#include "utils/str_utils.h"
#include "jtorch_core.h"
#include "graph.h"
#include "grad.h"
#include "ops/op_register.h"

namespace jittor {

void pyjt_def_all(PyObject* m);

EXTERN_LIB void setter_use_cuda(int value);

Device::Device(const string& name, int ordinal) : name(name) {
    if (startswith(name, "cpu"))
        setter_use_cuda(0);
    else
        setter_use_cuda(1);
}

unordered_map<int64, VarPtr> grad_backup;
EXTERN_LIB void (*_var_free_hook)(Var*);
EXTERN_LIB unordered_map<int64, VarPtr>* _grad_backup_ptr;

void jtorch_var_free_hook(Var* v) {
    auto iter = grad_backup.find(v->id);
    if (iter != grad_backup.end()) {
        grad_backup.erase(iter);
    }
}

void jtorch_init() {
    _var_free_hook = &jtorch_var_free_hook;
    _grad_backup_ptr = &grad_backup;
}

inline static VarPtr& get_grad(Var* v) {
    return grad_backup[v->id];
}

static auto make_binary = get_op_info("binary")
    .get_constructor<VarPtr, Var*, Var*, NanoString>();

inline static void add_grad(VarPtr& a, VarPtr&& b) {
    if (!a) a = move(b);
    else {
        a = make_binary(a, b, ns_add);
    }
}

void grad_set(VarHolder* x, Maybe<VarHolder> v) {
    if (!v) {
        grad_del(x);
        return;
    }
    grad_backup[x->var->id] = v.ptr->var;
}

Maybe<VarHolder> grad_get(VarHolder* x) {
    auto iter = grad_backup.find(x->var->id);
    if (iter != grad_backup.end()) {
        if (!iter->second.ptr) return nullptr;
        return new VarHolder(iter->second.ptr);
    }
    return nullptr;
}

void grad_del(VarHolder* x) {
    auto iter = grad_backup.find(x->var->id);
    if (iter != grad_backup.end())
        grad_backup.erase(iter);
}

void backward(VarHolder* x) {
    vector<Node*> gnodes({x->var});
    bfs_backward(gnodes, [&](Node* node) {
        if (node->is_stop_grad())
            return false;
        return true;
    });
    vector<Var*> targets;
    for (auto* node : gnodes) {
        if (node->is_var() && node->flags.get(NodeFlags::_th_require_grad))
            targets.push_back(node->var());
    }
    auto grads = grad(x->var, targets);
    for (int i=0; i<targets.size(); i++) {
        auto& gptr = get_grad(targets[i]);
        add_grad(gptr, move(grads[i]));
    }
}

}

static void init_module(PyModuleDef* mdef, PyObject* m) {
    jittor::jtorch_init();
    mdef->m_doc = "Inner c++ core of jtorch";
    jittor::pyjt_def_all(m);
}
PYJT_MODULE_INIT(jtorch_core);

@ -0,0 +1,40 @@
#pragma once
#include "common.h"
#include "var_holder.h"
#include "misc/fast_shared_ptr.h"

namespace jittor {

// @pyjt(device)
// @attrs(heaptype)
struct Device {
    string name;

    // @pyjt(__init__)
    Device(const string& name, int ordinal=0);
    // @pyjt(__get__type, __str__)
    inline string get_type() {return name;}
    // @pyjt(__get__index)
    inline int index() {return 0;}
};

// @pyjt(backward)
void backward(VarHolder* x);

// @pyjt(grad_set)
void grad_set(VarHolder* x, Maybe<VarHolder> v);
// @pyjt(grad_get)
Maybe<VarHolder> grad_get(VarHolder* x);
// @pyjt(grad_del)
void grad_del(VarHolder* x);

// @pyjt(retain_grad_set)
inline void retain_grad_set(VarHolder* x, bool v) {
    x->var->flags.set(NodeFlags::_th_require_grad, v);
}
// @pyjt(retain_grad_get)
inline bool retain_grad_get(VarHolder* x) {
    return x->var->flags.get(NodeFlags::_th_require_grad);
}

}

@ -0,0 +1,25 @@
import unittest
import numpy as np
import torch
import jittor as jt

class TestConflictFunc(unittest.TestCase):
    def test_max(self):
        a = torch.Tensor([1,4,2])
        assert a.max() == 4
        v, k = a.max(dim=0)
        assert v==4 and k==1

    def test_argsort(self):
        a = torch.Tensor([1,4,2])
        k = a.argsort()
        assert jt.all_equal(k, [0,2,1])

        with jt.flag_scope(th_mode=0):
            k, v = a.argsort()
            assert jt.all_equal(k, [0,2,1])

if __name__ == "__main__":
    unittest.main()

@ -0,0 +1,58 @@
import unittest
import numpy as np
import torch

class TestFunction(unittest.TestCase):
    def test_example1(self):
        import jtorch
        from jtorch import Function

        class MyFunc(Function):
            @staticmethod
            def forward(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y

            @staticmethod
            def backward(self, grad0, grad1):
                return grad0 * self.y, grad1 * self.x

        a = jtorch.array(3.0)
        a.requires_grad = True
        b = jtorch.array(4.0)
        b.requires_grad = True
        func = MyFunc.apply
        c,d = func(a, b)
        (c+d*3).backward()
        assert a.grad.data == 4
        assert b.grad.data == 9

    def test_example2(self):
        import jtorch as jt
        from jtorch import Function

        class MyFunc(Function):
            @staticmethod
            def forward(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y

            @staticmethod
            def backward(self, grad0, grad1):
                assert grad1 is None
                return grad0 * self.y, None

        a = jt.array(3.0)
        a.requires_grad = True
        b = jt.array(4.0)
        b.requires_grad = True
        func = MyFunc.apply
        c,d = func(a, b)
        d.stop_grad()
        da, db = jt.grad(c+d*3, [a, b])
        assert da.data == 4
        assert db.data == 0

if __name__ == "__main__":
    unittest.main()

@ -0,0 +1,24 @@
import unittest
import numpy as np
import torch

class TestMisc(unittest.TestCase):
    def test_update_grad(self):
        class Net(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.Parameter(torch.Tensor([1.0, 2.0]))
        net = Net()
        assert(net.a.requires_grad)
        net.load_state_dict({"a": torch.Tensor([3.0, 4.0])})
        assert(net.a.requires_grad)

    def test_reshape(self):
        a = torch.ones(3,3)
        a.requires_grad = True
        b = torch.reshape(a, [9])
        assert b.requires_grad == True

if __name__ == "__main__":
    unittest.main()

@ -0,0 +1,56 @@
import unittest
import numpy as np
import os
import subprocess as sp
import sys

def check_two(cmd, parser=None, checker=None):
    # Run `cmd` once with jtorch on PYTHONPATH and once against plain torch
    # (PYTHONPATH cleared), then compare the two (optionally parsed) outputs.
    jtorch_out = sp.getoutput(cmd)
    print("=========JTORCH OUT==========")
    print(jtorch_out)
    torch_out = sp.getoutput("PYTHONPATH= "+cmd)
    print("=========TORCH OUT==========")
    print(torch_out)
    if parser:
        torch_out = parser(torch_out)
        jtorch_out = parser(jtorch_out)
    if checker:
        checker(torch_out, jtorch_out)
    else:
        assert torch_out == jtorch_out
    return jtorch_out, torch_out

jtorch_path = os.path.join(os.path.dirname(__file__), "..")
# adapted from https://pytorch.org/tutorials/beginner/pytorch_with_examples.html
class TestTutorial(unittest.TestCase):
    def test_auto_grad1(self):
        check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad1.py",
                  parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
                  checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
    def test_auto_grad2(self):
        check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad2.py",
                  parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
                  checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
    def test_auto_grad3(self):
        check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad3.py",
                  parser=lambda s: np.array(s.split())[[-9,-7,-4,-2]].astype(float),
                  checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
    def test_auto_grad4(self):
        check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad4.py",
                  parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
                  checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
    def test_auto_grad5(self):
        check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad5_optim.py",
                  parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
                  checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2))
    def test_auto_grad6(self):
        check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad6_module.py",
                  parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float),
                  checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4))
    def test_auto_grad7(self):
        check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad7_dynet.py",
                  parser=lambda s: np.array(s.split())[[-13,-10,-7,-3]].astype(float),
                  checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2))

if __name__ == "__main__":
    unittest.main()

@ -0,0 +1,44 @@
import torch
import math

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# Create random input and output data
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Randomly initialize weights
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(20000):
    # Forward pass: compute predicted y
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    if t % 1000 == 999:
        print(t, loss)

    # Backprop to compute gradients of a, b, c, d with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x ** 2).sum()
    grad_d = (grad_y_pred * x ** 3).sum()

    # Update weights using gradient descent
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d
    # print(t, torch.liveness_info())
    # torch.sync_all()

print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

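# The hand-written gradients above follow from L = sum((y_pred - y)^2):
# dL/d(y_pred) = 2 (y_pred - y), and the chain rule then gives
# dL/da = sum(grad_y_pred), dL/db = sum(grad_y_pred * x),
# dL/dc = sum(grad_y_pred * x^2), dL/dd = sum(grad_y_pred * x^3).
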
@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
import torch
import math

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Create random Tensors for weights. For a third order polynomial, we need
# 4 weights: y = a + b x + c x^2 + d x^3
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.randn((), device=device, dtype=dtype, requires_grad=True)
b = torch.randn((), device=device, dtype=dtype, requires_grad=True)
c = torch.randn((), device=device, dtype=dtype, requires_grad=True)
d = torch.randn((), device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(20000):
    # Forward pass: compute predicted y using operations on Tensors.
    y_pred = a + b * x + c * x ** 2 + d * x ** 3
    # print(y_pred.requires_grad)
    # y_pred.requires_grad = False

    # Compute and print loss using operations on Tensors.
    # Now loss is a Tensor of shape (1,)
    # loss.item() gets the scalar value held in the loss.
    loss = (y_pred - y).pow(2).sum()
    if t % 1000 == 999:
        print(t, loss.item())

    # Use autograd to compute the backward pass. This call will compute the
    # gradient of loss with respect to all Tensors with requires_grad=True.
    # After this call a.grad, b.grad, c.grad and d.grad will be Tensors holding
    # the gradient of the loss with respect to a, b, c, d respectively.
    # torch.backward(loss)
    loss.backward()

    # Manually update weights using gradient descent. Wrap in torch.no_grad()
    # because weights have requires_grad=True, but we don't need to track this
    # in autograd.
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # Manually zero the gradients after updating weights
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
import torch
import math


class LegendrePolynomial3(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return 0.5 * (5 * input ** 3 - 3 * input)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        return grad_output * 1.5 * (5 * input ** 2 - 1)


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Create random Tensors for weights. For this example, we need
# 4 weights: y = a + b * P3(c + d * x), these weights need to be initialized
# not too far from the correct result to ensure convergence.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)

learning_rate = 5e-6
for t in range(2000):
    # To apply our Function, we use Function.apply method. We alias this as 'P3'.
    P3 = LegendrePolynomial3.apply

    # Forward pass: compute predicted y using operations; we compute
    # P3 using our custom autograd operation.
    y_pred = a + b * P3(c + d * x)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # Manually zero the gradients after updating weights
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

print(f'Result: y = {a.item()} + {b.item()} * P3( {c.item()} + {d.item()} x)')

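# Sanity check for the backward pass above: with P3(x) = 0.5 (5 x^3 - 3 x),
# P3'(x) = 0.5 (15 x^2 - 3) = 1.5 (5 x^2 - 1), which is exactly the factor
# multiplied into grad_output in LegendrePolynomial3.backward.
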
@ -0,0 +1,71 @@
# -*- coding: utf-8 -*-
import torch
import math


# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# For this example, the output y is a linear function of (x, x^2, x^3), so
# we can consider it as a linear layer neural network. Let's prepare the
# tensor (x, x^2, x^3).
p = torch.tensor([1, 2, 3])
xx = x.unsqueeze(-1).pow(p)

# In the above code, x.unsqueeze(-1) has shape (2000, 1), and p has shape
# (3,), for this case, broadcasting semantics will apply to obtain a tensor
# of shape (2000, 3)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. The Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
# The Flatten layer flattens the output of the linear layer to a 1D tensor,
# to match the shape of `y`.
model = torch.nn.Sequential(
    torch.nn.Linear(3, 1),
    torch.nn.Flatten(0, 1)
)

# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
loss_fn = torch.nn.MSELoss(reduction='sum')
# print(model[0].weight.requires_grad)

learning_rate = 1e-6
for t in range(8000):

    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(xx)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the
    # loss.
    loss = loss_fn(y_pred, y)
    if t % 1000 == 999:
        print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Tensor, so
    # we can access its gradients like we did before.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

# You can access the first layer of `model` like accessing the first item of a list
linear_layer = model[0]

# For linear layer, its parameters are stored as `weight` and `bias`.
print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3')

@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
import torch
import math


# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Prepare the input tensor (x, x^2, x^3).
p = torch.tensor([1, 2, 3])
xx = x.unsqueeze(-1).pow(p)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(3, 1),
    torch.nn.Flatten(0, 1)
)
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use RMSprop; the optim package contains many other
# optimization algorithms. The first argument to the RMSprop constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-3
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
for t in range(8000):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(xx)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    if t % 1000 == 999:
        print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because by default, gradients are
    # accumulated in buffers (i.e., not overwritten) whenever .backward()
    # is called. Checkout docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()


linear_layer = model[0]
print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3')

@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
import torch
import math


class Polynomial3(torch.nn.Module):
    def __init__(self):
        """
        In the constructor we instantiate four parameters and assign them as
        member parameters.
        """
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(()))
        self.b = torch.nn.Parameter(torch.randn(()))
        self.c = torch.nn.Parameter(torch.randn(()))
        self.d = torch.nn.Parameter(torch.randn(()))

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3

    def string(self):
        """
        Just like any class in Python, you can also define custom methods on PyTorch modules
        """
        return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3'


# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Construct our model by instantiating the class defined above
model = Polynomial3()

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters (defined
# with torch.nn.Parameter) which are members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)
for t in range(8000):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    if t % 1000 == 999:
        print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f'Result: {model.string()}')

@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
import random
import torch
import math


class DynamicNet(torch.nn.Module):
    def __init__(self):
        """
        In the constructor we instantiate five parameters and assign them as members.
        """
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(()))
        self.b = torch.nn.Parameter(torch.randn(()))
        self.c = torch.nn.Parameter(torch.randn(()))
        self.d = torch.nn.Parameter(torch.randn(()))
        self.e = torch.nn.Parameter(torch.randn(()))

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose either 4 or 5
        and reuse the e parameter to compute the contribution of these orders.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same parameter many
        times when defining a computational graph.
        """
        y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3
        for exp in range(4, random.randint(4, 6)):
            y = y + self.e * x ** exp
        return y

    def string(self):
        """
        Just like any class in Python, you can also define custom methods on PyTorch modules
        """
        return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?'


# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Construct our model by instantiating the class defined above
model = DynamicNet()

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9)
for t in range(60000):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    if t % 2000 == 1999:
        print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # print(torch.liveness_info())

print(f'Result: {model.string()}')

@ -0,0 +1,106 @@
import torch
from torch import nn
# from jtorch.utils import DataLoader
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

print(len(train_dataloader))
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)


loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


epochs = 5
test(test_dataloader, model, loss_fn)
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

@ -0,0 +1,5 @@
cpp_extension = None
_flatten_dense_tensors = None
_unflatten_dense_tensors = None

tensorboard = None

@ -0,0 +1,3 @@
# TODO: Implement this
_register_pytree_node = None
_dict_flatten = None

@ -0,0 +1,8 @@
detach_variable = None


# Stub: gradient checkpointing is not implemented in this port; the call is
# accepted and ignored.
def checkpoint(
    *args,
    **kwargs
):
    pass

@@ -0,0 +1,137 @@
import jittor as jt
import jittor.dataset
from jittor.dataset import Dataset as JDataset

from collections import namedtuple
from typing import Any, Callable, Iterable, Optional, Sequence, Union


class Dataset:
    def __getitem__(self, index):
        raise NotImplementedError

class IterableDataset:
    def __iter__(self):
        raise NotImplementedError


class DataLoader(JDataset):
    def __init__(self, dataset,
                 batch_size: Optional[int] = 1,
                 shuffle: Optional[bool] = False,
                 sampler = None,
                 batch_sampler = None,
                 num_workers: int = 0,
                 collate_fn = None,
                 pin_memory: bool = False,
                 drop_last: bool = False,
                 timeout: float = 0,
                 worker_init_fn = None,
                 multiprocessing_context=None,
                 generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False,
                 pin_memory_device: str = "") -> None:
        super().__init__(batch_size=batch_size,
                         shuffle=shuffle,
                         num_workers=num_workers,
                         drop_last=drop_last)

        unsupported_kwargs = {
            "batch_sampler": batch_sampler,
            "pin_memory": pin_memory,
            "timeout": timeout,
            "worker_init_fn": worker_init_fn,
            "multiprocessing_context": multiprocessing_context,
            "generator": generator,
            "persistent_workers": persistent_workers,
            "pin_memory_device": pin_memory_device
        }
        for kwarg, value in unsupported_kwargs.items():
            if value:
                jt.LOG.w(f"Not implemented Dataloader kwarg: {kwarg}")

        self.dataset = dataset
        self.collate_fn = collate_fn
        self.sampler = sampler

        if not isinstance(dataset, IterableDataset):
            self.total_len = len(dataset)
        else:
            # TODO: support multiple workers for iterable datasets
            assert num_workers == 0

    def collate_batch(self, batch):
        if self.collate_fn is not None:
            return self.collate_fn(batch)
        else:
            return super().collate_batch(batch)

    def __getitem__(self, i):
        return self.dataset[i]

    def __iter__(self):
        if isinstance(self.dataset, IterableDataset):
            return self.inner_iter()
        else:
            return super().__iter__()

    def inner_iter(self):
        current_batch = []

        if jt.world_size > 1:
            assert self.batch_size % jt.world_size == 0, \
                f"IterableDataset does not support a batch size ({self.batch_size}) that is not evenly divisible by the number of processes ({jt.world_size})"
            real_batch_size = self.batch_size // jt.world_size
        else:
            real_batch_size = self.batch_size

        for element in self.dataset:
            current_batch.append(element)

            if len(current_batch) == real_batch_size:
                current_batch = self.collate_batch(current_batch)
                current_batch = self.to_jittor(current_batch)
                yield current_batch
                current_batch = []

        if not self.drop_last and len(current_batch) > 0:
            current_batch = self.collate_batch(current_batch)
            yield self.to_jittor(current_batch)


# def get_worker_info():
#     # always return the fake worker info
#     return namedtuple('WorkerInfo', 'id num_workers')(0, 1)


# class RandomSampler(jt.dataset.RandomSampler):
#     def __init__(self, dataset, generator=None, **kwargs):
#         super().__init__(dataset, **kwargs)

#     def __iter__(self):
#         if getattr(self.dataset, "support_random_access", True):
#             return super().__iter__()
#         else:
#             self.dataset.shuffle()
#             return iter(range(self.dataset.__real_len__() if hasattr(self.dataset, "__real_len__") else self.dataset.__len__()))


# class DistributedSampler(jt.dataset.Sampler):
#     def __init__(self, sampler: RandomSampler):
#         assert(isinstance(sampler, RandomSampler))
#         self.sampler = sampler

#     def set_epoch(self, epoch: int):
#         ### do nothing, let jittor's inner dataset handle
#         pass

#     def __iter__(self):
#         return self.sampler.__iter__()

#     def __len__(self):
#         return self.sampler.__len__()

# BatchSampler = jt.dataset.BatchSampler
# Sampler = jt.dataset.Sampler
# SequentialSampler = jt.dataset.SequentialSampler
# SubsetRandomSampler = jt.dataset.SubsetRandomSampler

# TensorDataset = Dataset
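A minimal usage sketch of the DataLoader above with an IterableDataset (illustrative only; the import path jtorch.utils.data is an assumption, and Jittor must be installed):

from jtorch.utils.data import DataLoader, IterableDataset  # hypothetical import path

class CountStream(IterableDataset):
    # yields the integers 0..9 one at a time
    def __iter__(self):
        return iter(range(10))

# Batches of 4; the trailing partial batch of 2 is kept because drop_last defaults to False.
loader = DataLoader(CountStream(), batch_size=4)
for batch in loader:
    print(batch)
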
@@ -0,0 +1,9 @@
from typing import Callable, Union
Dtype = Union[Callable, str]

def get_string_dtype(dtype):
    if callable(dtype):
        dtype = dtype.__name__
    if not isinstance(dtype, str):
        raise ValueError(f"dtype is expected to be str, python type function, or jittor type function, but got {dtype}.")
    return dtype
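For example, get_string_dtype accepts either spelling of a dtype; a small sketch of the expected behavior:

get_string_dtype(float)      # -> "float" (callable, so __name__ is taken)
get_string_dtype("float32")  # -> "float32" (already a string)
get_string_dtype(3.14)       # raises ValueError: neither callable nor a str
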
@@ -0,0 +1,34 @@
import os
import glob
import shutil
import sys

home_path = os.path.join(os.path.dirname(__file__), "..", "..", "..")
home_path = os.path.abspath(home_path)

def callback(func, path, exc_info):
    print(f"remove \"{path}\" failed.")

def rmtree(path):
    if os.path.isdir(path):
        print(f"remove \"{path}\" recursively.")
        shutil.rmtree(path, onerror=callback)

def remove_tmpfile():
    dist_file = home_path+"/dist"
    egg_file = glob.glob(home_path+"/**/*egg-info")
    rmtree(dist_file)
    for e in egg_file:
        rmtree(e)

def run_cmd(cmd):
    print("[CMD]", cmd)
    assert os.system(cmd)==0

os.chdir(home_path)
remove_tmpfile()

run_cmd(f"{sys.executable} ./setup.py sdist")
run_cmd(f"{sys.executable} -m twine upload dist/*")

remove_tmpfile()
@@ -0,0 +1,46 @@
import importlib.machinery
import os


def _download_file_from_remote_location(fpath: str, url: str) -> None:
    pass


def _is_remote_location_available() -> bool:
    return False


def _get_extension_path(lib_name):

    lib_dir = os.path.dirname(__file__)
    if os.name == "nt":
        # Register the main torchvision library location on the default DLL path
        import ctypes
        import sys

        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
        with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
        prev_error_mode = kernel32.SetErrorMode(0x0001)

        if with_load_library_flags:
            kernel32.AddDllDirectory.restype = ctypes.c_void_p

        if sys.version_info >= (3, 8):
            os.add_dll_directory(lib_dir)
        elif with_load_library_flags:
            res = kernel32.AddDllDirectory(lib_dir)
            if res is None:
                err = ctypes.WinError(ctypes.get_last_error())
                err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
                raise err

        kernel32.SetErrorMode(prev_error_mode)

    loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)

    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
    ext_specs = extfinder.find_spec(lib_name)
    if ext_specs is None:
        raise ImportError

    return ext_specs.origin
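_get_extension_path resolves a compiled extension that lives next to this file; a hedged sketch ("_C" is a hypothetical extension name):

so_path = _get_extension_path("_C")
# -> absolute path such as .../package_dir/_C.cpython-310-x86_64-linux-gnu.so,
#    or ImportError if no binary with a matching extension suffix is found
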
@@ -0,0 +1,9 @@
from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST

__all__ = (
    "EMNIST",
    "FashionMNIST",
    "QMNIST",
    "MNIST",
    "KMNIST",
)
@@ -0,0 +1,558 @@
import codecs
import os
import os.path
import shutil
import string
import sys
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.error import URLError

import numpy as np
import torch
from PIL import Image

from .utils import check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
from .vision import VisionDataset


class MNIST(VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
            and ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    mirrors = [
        "http://yann.lecun.com/exdb/mnist/",
        "https://ossci-datasets.s3.amazonaws.com/mnist/",
    ]

    resources = [
        ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
    ]

    training_file = "training.pt"
    test_file = "test.pt"
    classes = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]

    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.train = train  # training set or test set

        if self._check_legacy_exist():
            self.data, self.targets = self._load_legacy_data()
            return

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        self.data, self.targets = self._load_data()

    def _check_legacy_exist(self):
        processed_folder_exists = os.path.exists(self.processed_folder)
        if not processed_folder_exists:
            return False

        return all(
            check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
        )

    def _load_legacy_data(self):
        # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
        # directly.
        data_file = self.training_file if self.train else self.test_file
        return torch.load(os.path.join(self.processed_folder, data_file))

    def _load_data(self):
        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
        data = read_image_file(os.path.join(self.raw_folder, image_file))

        label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
        targets = read_label_file(os.path.join(self.raw_folder, label_file))

        return data, targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode="L")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def raw_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, "raw")

    @property
    def processed_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, "processed")

    @property
    def class_to_idx(self) -> Dict[str, int]:
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self) -> bool:
        return all(
            check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
            for url, _ in self.resources
        )

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist already."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        # download files
        for filename, md5 in self.resources:
            for mirror in self.mirrors:
                url = f"{mirror}{filename}"
                try:
                    print(f"Downloading {url}")
                    download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
                except URLError as error:
                    print(f"Failed to download (trying next):\n{error}")
                    continue
                finally:
                    print()
                break
            else:
                raise RuntimeError(f"Error downloading {filename}")

    def extra_repr(self) -> str:
        split = "Train" if self.train is True else "Test"
        return f"Split: {split}"

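A short usage sketch for the MNIST class above (illustrative; download=True needs network access):

train_set = MNIST(root="./data", train=True, download=True)
img, target = train_set[0]   # PIL.Image.Image of size 28x28, int label
print(len(train_set), train_set.classes[target])
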
class FashionMNIST(MNIST):
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte``
            and ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]

    resources = [
        ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
        ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"),
        ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"),
        ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"),
    ]
    classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]


class KMNIST(MNIST):
    """`Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte``
            and ``KMNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]

    resources = [
        ("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
        ("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
        ("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
        ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"),
    ]
    classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"]


class EMNIST(MNIST):
    """`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``EMNIST/raw/train-images-idx3-ubyte``
            and ``EMNIST/raw/t10k-images-idx3-ubyte`` exist.
        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
            which one to use.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    url = "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip"
    md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
    splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist")
    # The merged classes assume the same structure for both the uppercase and lowercase versions
    _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"}
    _all_classes = set(string.digits + string.ascii_letters)
    classes_split_dict = {
        "byclass": sorted(list(_all_classes)),
        "bymerge": sorted(list(_all_classes - _merged_classes)),
        "balanced": sorted(list(_all_classes - _merged_classes)),
        "letters": ["N/A"] + list(string.ascii_lowercase),
        "digits": list(string.digits),
        "mnist": list(string.digits),
    }

    def __init__(self, root: str, split: str, **kwargs: Any) -> None:
        self.split = verify_str_arg(split, "split", self.splits)
        self.training_file = self._training_file(split)
        self.test_file = self._test_file(split)
        super().__init__(root, **kwargs)
        self.classes = self.classes_split_dict[self.split]

    @staticmethod
    def _training_file(split) -> str:
        return f"training_{split}.pt"

    @staticmethod
    def _test_file(split) -> str:
        return f"test_{split}.pt"

    @property
    def _file_prefix(self) -> str:
        return f"emnist-{self.split}-{'train' if self.train else 'test'}"

    @property
    def images_file(self) -> str:
        return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")

    @property
    def labels_file(self) -> str:
        return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")

    def _load_data(self):
        return read_image_file(self.images_file), read_label_file(self.labels_file)

    def _check_exists(self) -> bool:
        return all(check_integrity(file) for file in (self.images_file, self.labels_file))

    def download(self) -> None:
        """Download the EMNIST data if it doesn't exist already."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
        gzip_folder = os.path.join(self.raw_folder, "gzip")
        for gzip_file in os.listdir(gzip_folder):
            if gzip_file.endswith(".gz"):
                extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
        shutil.rmtree(gzip_folder)


class QMNIST(MNIST):
    """`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset whose ``raw``
            subdir contains binary files of the datasets.
        what (string, optional): Can be 'train', 'test', 'test10k',
            'test50k', or 'nist' for respectively the mnist compatible
            training set, the 60k qmnist testing set, the 10k qmnist
            examples that match the mnist testing set, the 50k
            remaining qmnist testing examples, or all the nist
            digits. The default is to select 'train' or 'test'
            according to the compatibility argument 'train'.
        compat (bool, optional): A boolean that says whether the target
            for each example is class number (for compatibility with
            the MNIST dataloader) or a torch vector containing the
            full qmnist information. Default=True.
        download (bool, optional): If True, downloads the dataset from
            the internet and puts it in root directory. If dataset is
            already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that
            takes in a PIL image and returns a transformed
            version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform
            that takes in the target and transforms it.
        train (bool, optional, compatibility): When argument 'what' is
            not specified, this boolean decides whether to load the
            training set or the testing set. Default: True.
    """

    subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"}
    resources: Dict[str, List[Tuple[str, str]]] = {  # type: ignore[assignment]
        "train": [
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz",
                "ed72d4157d28c017586c42bc6afe6370",
            ),
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz",
                "0058f8dd561b90ffdd0f734c6a30e5e4",
            ),
        ],
        "test": [
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz",
                "1394631089c404de565df7b7aeaf9412",
            ),
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz",
                "5b5b05890a5e13444e108efe57b788aa",
            ),
        ],
        "nist": [
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz",
                "7f124b3b8ab81486c9d8c2749c17f834",
            ),
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz",
                "5ed0e788978e45d4a8bd4b7caec3d79d",
            ),
        ],
    }
    classes = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]

    def __init__(
        self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any
    ) -> None:
        if what is None:
            what = "train" if train else "test"
        self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
        self.compat = compat
        self.data_file = what + ".pt"
        self.training_file = self.data_file
        self.test_file = self.data_file
        super().__init__(root, train, **kwargs)

    @property
    def images_file(self) -> str:
        (url, _), _ = self.resources[self.subsets[self.what]]
        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])

    @property
    def labels_file(self) -> str:
        _, (url, _) = self.resources[self.subsets[self.what]]
        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])

    def _check_exists(self) -> bool:
        return all(check_integrity(file) for file in (self.images_file, self.labels_file))

    def _load_data(self):
        data = read_sn3_pascalvincent_tensor(self.images_file)
        if data.dtype != torch.uint8:
            raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}")
        if data.ndimension() != 3:
            raise ValueError(f"data should have 3 dimensions instead of {data.ndimension()}")

        targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
        if targets.ndimension() != 2:
            raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}")

        if self.what == "test10k":
            data = data[0:10000, :, :].clone()
            targets = targets[0:10000, :].clone()
        elif self.what == "test50k":
            data = data[10000:, :, :].clone()
            targets = targets[10000:, :].clone()

        return data, targets

    def download(self) -> None:
        """Download the QMNIST data if it doesn't exist already.
        Note that we only download what has been asked for (argument 'what').
        """
        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)
        split = self.resources[self.subsets[self.what]]

        for url, md5 in split:
            download_and_extract_archive(url, self.raw_folder, md5=md5)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        # redefined to handle the compat flag
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img.numpy(), mode="L")
        if self.transform is not None:
            img = self.transform(img)
        if self.compat:
            target = int(target[0])
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def extra_repr(self) -> str:
        return f"Split: {self.what}"


def get_int(b: bytes) -> int:
    return int(codecs.encode(b, "hex"), 16)


SN3_PASCALVINCENT_BITSMAP = {
    8: torch.uint8,
    9: torch.int8,
    11: torch.int16,
    12: torch.int32,
    13: torch.float32,
    14: torch.float64,
}

TORCH_TYPE_BITS = {
    torch.uint8: 8,
    torch.int8: 8,
    torch.int16: 16,
    torch.int32: 32,
    torch.float32: 32,
    torch.float64: 64,
}


def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
    """Read an SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
    Argument may be a filename, compressed filename, or file object.
    """
    # read
    with open(path, "rb") as f:
        data = f.read()
    # parse
    magic = get_int(data[0:4])
    nd = magic % 256
    ty = magic // 256
    assert 1 <= nd <= 3
    assert 8 <= ty <= 14
    torch_type = SN3_PASCALVINCENT_BITSMAP[ty]
    s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]

    num_bytes_per_value = TORCH_TYPE_BITS[torch_type] // 8
    # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
    # we need to reverse the bytes before we can read them with torch.frombuffer().
    needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
    parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))
    if needs_byte_reversal:
        parsed = parsed.flip(0)

    assert parsed.shape[0] == np.prod(s) or not strict
    return parsed.view(*s)


def read_label_file(path: str) -> torch.Tensor:
    x = read_sn3_pascalvincent_tensor(path, strict=False)
    if x.dtype != torch.uint8:
        raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
    if x.ndimension() != 1:
        raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}")
    return x.long()


def read_image_file(path: str) -> torch.Tensor:
    x = read_sn3_pascalvincent_tensor(path, strict=False)
    if x.dtype != torch.uint8:
        raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
    if x.ndimension() != 3:
raise ValueError(f"x should have 3 dimension instead of {x.ndimension()}")
|
||||
return x
|
|
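As a worked example of the magic-number scheme used by read_sn3_pascalvincent_tensor, the MNIST training images file begins with the bytes 00 00 08 03:

magic = get_int(b"\x00\x00\x08\x03")  # 0x00000803 == 2051
nd = magic % 256   # 3  -> the tensor has 3 dimensions
ty = magic // 256  # 8  -> torch.uint8 via SN3_PASCALVINCENT_BITSMAP
# The next nd big-endian int32 words hold the shape, e.g. (60000, 28, 28).
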
@@ -0,0 +1,522 @@
import bz2
import contextlib
import gzip
import hashlib
import itertools
import lzma
import os
import os.path
import pathlib
import re
import sys
import tarfile
import urllib
import urllib.error
import urllib.request
import warnings
import zipfile
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar
from urllib.parse import urlparse

import numpy as np
import requests
import torch
from tqdm import tqdm

from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available

USER_AGENT = "pytorch/vision"


def _save_response_content(
    content: Iterator[bytes],
    destination: str,
    length: Optional[int] = None,
) -> None:
    with open(destination, "wb") as fh, tqdm(total=length) as pbar:
        for chunk in content:
            # filter out keep-alive new chunks
            if not chunk:
                continue

            fh.write(chunk)
            pbar.update(len(chunk))


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
        _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)

def gen_bar_updater() -> Callable[[int, int, int], None]:
    warnings.warn("The function `gen_bar_updater` is deprecated since 0.13 and will be removed in 0.15.")
    pbar = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        if pbar.total is None and total_size:
            pbar.total = total_size
        progress_bytes = count * block_size
        pbar.update(progress_bytes - pbar.n)

    return bar_update

def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
    # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
    # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
    # it torchvision.datasets is unusable in these environments since we perform an MD5 check everywhere.
    if sys.version_info >= (3, 9):
        md5 = hashlib.md5(usedforsecurity=False)
    else:
        md5 = hashlib.md5()
    with open(fpath, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    return check_md5(fpath, md5)


def _get_redirect_url(url: str, max_hops: int = 3) -> str:
    initial_url = url
    headers = {"Method": "HEAD", "User-Agent": USER_AGENT}

    for _ in range(max_hops + 1):
        with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
            if response.url == url or response.url is None:
                return url

            url = response.url
    else:
        raise RecursionError(
            f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}."
        )


def _get_google_drive_file_id(url: str) -> Optional[str]:
    parts = urlparse(url)

    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
        return None

    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
    if match is None:
        return None

    return match.group("id")


def download_url(
    url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
) -> None:
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check
        max_redirect_hops (int, optional): Maximum number of redirect hops allowed
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print("Using downloaded and verified file: " + fpath)
        return

    if _is_remote_location_available():
        _download_file_from_remote_location(fpath, url)
    else:
        # expand redirect chain if needed
        url = _get_redirect_url(url, max_hops=max_redirect_hops)

        # check if file is located on Google Drive
        file_id = _get_google_drive_file_id(url)
        if file_id is not None:
            return download_file_from_google_drive(file_id, root, filename, md5)

        # download the file
        try:
            print("Downloading " + url + " to " + fpath)
            _urlretrieve(url, fpath)
        except (urllib.error.URLError, OSError) as e:  # type: ignore[attr-defined]
            if url[:5] == "https":
                url = url.replace("https:", "http:")
                print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
                _urlretrieve(url, fpath)
            else:
                raise e

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")

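A usage sketch of download_url with one of the MNIST resources listed earlier (requires network access):

download_url(
    "https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz",
    root="./data",
    md5="ec29112dd5afa0611ce80d1b7f02629c",
)
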
def list_dir(root: str, prefix: bool = False) -> List[str]:
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
    if prefix is True:
        directories = [os.path.join(root, d) for d in directories]
    return directories


def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
    if prefix is True:
        files = [os.path.join(root, d) for d in files]
    return files


def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]:
    content = response.iter_content(chunk_size)
    first_chunk = None
    # filter out keep-alive new chunks
    while not first_chunk:
        first_chunk = next(content)
    content = itertools.chain([first_chunk], content)

    try:
        match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
        api_response = match["api_response"] if match is not None else None
    except UnicodeDecodeError:
        api_response = None
    return api_response, content


def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
    """Download a Google Drive file and place it in root.

    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the id of the file.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url

    root = os.path.expanduser(root)
    if not filename:
        filename = file_id
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    if check_integrity(fpath, md5):
        print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
        return

    url = "https://drive.google.com/uc"
    params = dict(id=file_id, export="download")
    with requests.Session() as session:
        response = session.get(url, params=params, stream=True)

        for key, value in response.cookies.items():
            if key.startswith("download_warning"):
                token = value
                break
        else:
            api_response, content = _extract_gdrive_api_response(response)
            token = "t" if api_response == "Virus scan warning" else None

        if token is not None:
            response = session.get(url, params=dict(params, confirm=token), stream=True)
            api_response, content = _extract_gdrive_api_response(response)

        if api_response == "Quota exceeded":
            raise RuntimeError(
                f"The daily quota of the file {filename} is exceeded and it "
                f"can't be downloaded. This is a limitation of Google Drive "
                f"and can only be overcome by trying again later."
            )

        _save_response_content(content, fpath)

    # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
    if os.stat(fpath).st_size < 10 * 1024:
        with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh:
            text = fh.read()
            # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
            if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
                warnings.warn(
                    f"We detected some HTML elements in the downloaded file. "
                    f"This most likely means that the download triggered an unhandled API response by GDrive. "
                    f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
                    f"the response:\n\n{text}"
                )

    if md5 and not check_md5(fpath, md5):
        raise RuntimeError(
            f"The MD5 checksum of the downloaded file {fpath} does not match the one on record. "
            f"Please delete the file and try again. "
            f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
        )


def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
    with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
        tar.extractall(to_path)


_ZIP_COMPRESSION_MAP: Dict[str, int] = {
    ".bz2": zipfile.ZIP_BZIP2,
    ".xz": zipfile.ZIP_LZMA,
}


def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
    with zipfile.ZipFile(
        from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
    ) as zip:
        zip.extractall(to_path)


_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
    ".tar": _extract_tar,
    ".zip": _extract_zip,
}
_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {
    ".bz2": bz2.open,
    ".gz": gzip.open,
    ".xz": lzma.open,
}
_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
    ".tbz": (".tar", ".bz2"),
    ".tbz2": (".tar", ".bz2"),
    ".tgz": (".tar", ".gz"),
}


def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Detect the archive type and/or compression of a file.

    Args:
        file (str): the filename

    Returns:
        (tuple): tuple of suffix, archive type, and compression

    Raises:
        RuntimeError: if file has no suffix or suffix is not supported
    """
    suffixes = pathlib.Path(file).suffixes
    if not suffixes:
        raise RuntimeError(
            f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
        )
    suffix = suffixes[-1]

    # check if the suffix is a known alias
    if suffix in _FILE_TYPE_ALIASES:
        return (suffix, *_FILE_TYPE_ALIASES[suffix])

    # check if the suffix is an archive type
    if suffix in _ARCHIVE_EXTRACTORS:
        return suffix, suffix, None

    # check if the suffix is a compression
    if suffix in _COMPRESSED_FILE_OPENERS:
        # check for suffix hierarchy
        if len(suffixes) > 1:
            suffix2 = suffixes[-2]

            # check if the suffix2 is an archive type
            if suffix2 in _ARCHIVE_EXTRACTORS:
                return suffix2 + suffix, suffix2, suffix

        return suffix, None, suffix

    valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS))
    raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")

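The return values of _detect_file_type follow the (suffix, archive_type, compression) convention; a sketch of the expected results:

_detect_file_type("a.tar.gz")  # (".tar.gz", ".tar", ".gz")  compressed archive
_detect_file_type("a.tgz")     # (".tgz", ".tar", ".gz")     via _FILE_TYPE_ALIASES
_detect_file_type("a.zip")     # (".zip", ".zip", None)      archive, no outer compression
_detect_file_type("a.gz")      # (".gz", None, ".gz")        plain compression, no archive
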
def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    r"""Decompress a file.

    The compression is automatically detected from the file name.

    Args:
        from_path (str): Path to the file to be decompressed.
        to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the decompressed file.
    """
    suffix, archive_type, compression = _detect_file_type(from_path)
    if not compression:
        raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")

    if to_path is None:
        to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]

    with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
        wfh.write(rfh.read())

    if remove_finished:
        os.remove(from_path)

    return to_path


def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    """Extract an archive.

    The archive type and a possible compression is automatically detected from the file name. If the file is compressed
    but not an archive, the call is dispatched to :func:`_decompress`.

    Args:
        from_path (str): Path to the file to be extracted.
        to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
            used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the directory the file was extracted to.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    suffix, archive_type, compression = _detect_file_type(from_path)
    if not archive_type:
        return _decompress(
            from_path,
            os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
            remove_finished=remove_finished,
        )

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    extractor = _ARCHIVE_EXTRACTORS[archive_type]

    extractor(from_path, to_path, compression)
    if remove_finished:
        os.remove(from_path)

    return to_path


def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print(f"Extracting {archive} to {extract_root}")
    extract_archive(archive, extract_root, remove_finished)


def iterable_to_str(iterable: Iterable) -> str:
    return "'" + "', '".join([str(item) for item in iterable]) + "'"


T = TypeVar("T", str, bytes)


def verify_str_arg(
    value: T,
    arg: Optional[str] = None,
    valid_values: Optional[Iterable[T]] = None,
    custom_msg: Optional[str] = None,
) -> T:
    if not isinstance(value, torch._six.string_classes):
        if arg is None:
            msg = "Expected type str, but got type {type}."
        else:
            msg = "Expected type str for argument {arg}, but got type {type}."
        msg = msg.format(type=type(value), arg=arg)
        raise ValueError(msg)

    if valid_values is None:
        return value

    if value not in valid_values:
        if custom_msg is not None:
            msg = custom_msg
        else:
            msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}."
            msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
        raise ValueError(msg)

    return value

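verify_str_arg in action; a sketch of the expected behavior:

split = verify_str_arg("train", "split", ("train", "test"))  # returns "train"
verify_str_arg("dev", "split", ("train", "test"))            # raises ValueError listing the valid values
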
def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
    """Read file in .pfm format. Might contain either 1 or 3 channels of data.

    Args:
        file_name (str): Path to the file.
        slice_channels (int): Number of channels to slice out of the file.
            Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc.
    """

    with open(file_name, "rb") as f:
        header = f.readline().rstrip()
        if header not in [b"PF", b"Pf"]:
            raise ValueError("Invalid PFM file")

        dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
        if not dim_match:
            raise Exception("Malformed PFM header.")
        w, h = (int(dim) for dim in dim_match.groups())

        scale = float(f.readline().rstrip())
        if scale < 0:  # little-endian
            endian = "<"
            scale = -scale
        else:
            endian = ">"  # big-endian

        data = np.fromfile(f, dtype=endian + "f")

        pfm_channels = 3 if header == b"PF" else 1

        data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1)
        data = np.flip(data, axis=1)  # flip on h dimension
        data = data[:slice_channels, :, :]
        return data.astype(np.float32)
@@ -0,0 +1,104 @@
import os
from typing import Any, Callable, List, Optional, Tuple

import torch
import torch.utils.data as data

from ..utils import _log_api_usage_once


class VisionDataset(data.Dataset):
    """
    Base class for making datasets which are compatible with torchvision.
    It is necessary to override the ``__getitem__`` and ``__len__`` methods.

    Args:
        root (string): Root directory of dataset.
        transforms (callable, optional): A function/transforms that takes in
            an image and a label and returns the transformed versions of both.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    .. note::
        :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
    """

    _repr_indent = 4

    def __init__(
        self,
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        self.root = root

        has_transforms = transforms is not None
        has_separate_transform = transform is not None or target_transform is not None
        if has_transforms and has_separate_transform:
            raise ValueError("Only transforms or transform/target_transform can be passed as argument")

        # for backwards-compatibility
        self.transform = transform
        self.target_transform = target_transform

        if has_separate_transform:
            transforms = StandardTransform(transform, target_transform)
        self.transforms = transforms

    def __getitem__(self, index: int) -> Any:
        """
        Args:
            index (int): Index

        Returns:
            (Any): Sample and meta data, optionally transformed by the respective transforms.
        """
        raise NotImplementedError

    def __len__(self) -> int:
        raise NotImplementedError

    def __repr__(self) -> str:
        head = "Dataset " + self.__class__.__name__
        body = [f"Number of datapoints: {self.__len__()}"]
        if self.root is not None:
            body.append(f"Root location: {self.root}")
        body += self.extra_repr().splitlines()
        if hasattr(self, "transforms") and self.transforms is not None:
            body += [repr(self.transforms)]
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return "\n".join(lines)

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        lines = transform.__repr__().splitlines()
        return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]

    def extra_repr(self) -> str:
        return ""


class StandardTransform:
    def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
        if self.transform is not None:
            input = self.transform(input)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return input, target

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        lines = transform.__repr__().splitlines()
        return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]

    def __repr__(self) -> str:
        body = [self.__class__.__name__]
        if self.transform is not None:
            body += self._format_transform_repr(self.transform, "Transform: ")
        if self.target_transform is not None:
            body += self._format_transform_repr(self.target_transform, "Target transform: ")

        return "\n".join(body)
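A minimal sketch of subclassing VisionDataset (hypothetical dataset; only __getitem__ and __len__ must be overridden):

class Squares(VisionDataset):
    def __init__(self, root, transform=None):
        super().__init__(root, transform=transform)
        self.values = list(range(100))

    def __getitem__(self, index):
        sample, target = self.values[index], self.values[index] ** 2
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

    def __len__(self):
        return len(self.values)
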
@@ -0,0 +1 @@
from jittor.transform import *
@@ -0,0 +1,582 @@
import collections
import math
import pathlib
import warnings
from itertools import repeat
from types import FunctionType
from typing import Any, BinaryIO, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image, ImageColor, ImageDraw, ImageFont

__all__ = [
    "make_grid",
    "save_image",
    "draw_bounding_boxes",
    "draw_segmentation_masks",
    "draw_keypoints",
    "flow_to_image",
]


@torch.no_grad()
def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: float = 0.0,
    **kwargs,
) -> torch.Tensor:
    """
    Make a grid of images.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list of images all of the same size.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
        padding (int, optional): amount of padding. Default: ``2``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by ``value_range``. Default: ``False``.
        value_range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. By default, min and max
            are computed from the tensor.
        range (tuple, optional):
            .. warning::
                This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``value_range``
                instead.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than the (min, max) over all images. Default: ``False``.
        pad_value (float, optional): Value for the padded pixels. Default: ``0``.

    Returns:
        grid (Tensor): the tensor containing grid of images.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(make_grid)
    if not torch.is_tensor(tensor):
        if isinstance(tensor, list):
            for t in tensor:
                if not torch.is_tensor(t):
                    raise TypeError(f"tensor or list of tensors expected, got a list containing {type(t)}")
        else:
            raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")

    if "range" in kwargs.keys():
        warnings.warn(
            "The parameter 'range' is deprecated since 0.12 and will be removed in 0.14. "
            "Please use 'value_range' instead."
        )
        value_range = kwargs["range"]

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize is True:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        if value_range is not None and not isinstance(value_range, tuple):
            raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers")

        def norm_ip(img, low, high):
            img.clamp_(min=low, max=high)
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, value_range)
        else:
            norm_range(tensor, value_range)

    if not isinstance(tensor, torch.Tensor):
        raise TypeError("tensor should be of type torch.Tensor")
    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            # Tensor.copy_() is a valid method but seems to be missing from the stubs
            # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
            grid.narrow(1, y * height + padding, height - padding).narrow(  # type: ignore[attr-defined]
                2, x * width + padding, width - padding
            ).copy_(tensor[k])
            k = k + 1
    return grid

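A small usage sketch for make_grid and save_image (assumes torch and PIL are installed; the file name is arbitrary):

batch = torch.rand(16, 3, 32, 32)      # 16 random RGB images in [0, 1]
grid = make_grid(batch, nrow=4)        # -> Tensor of shape (3, H, W)
save_image(batch, "grid.png", nrow=4)  # tiles the batch and writes a PNG
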
@torch.no_grad()
|
||||
def save_image(
|
||||
tensor: Union[torch.Tensor, List[torch.Tensor]],
|
||||
fp: Union[str, pathlib.Path, BinaryIO],
|
||||
format: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Save a given Tensor into an image file.
|
||||
|
||||
Args:
|
||||
tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
|
||||
saves the tensor as a grid of images by calling ``make_grid``.
|
||||
fp (string or file object): A filename or a file object
|
||||
format(Optional): If omitted, the format to use is determined from the filename extension.
|
||||
If a file object was used instead of a filename, this parameter should always be used.
|
||||
**kwargs: Other arguments are documented in ``make_grid``.
|
||||
"""
|
||||
|
||||
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
||||
_log_api_usage_once(save_image)
|
||||
grid = make_grid(tensor, **kwargs)
|
||||
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
|
||||
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
||||
im = Image.fromarray(ndarr)
|
||||
im.save(fp, format=format)
|
||||
|
||||
|
||||
@torch.no_grad()
def draw_bounding_boxes(
    image: torch.Tensor,
    boxes: torch.Tensor,
    labels: Optional[List[str]] = None,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
    fill: Optional[bool] = False,
    width: int = 1,
    font: Optional[str] = None,
    font_size: Optional[int] = None,
) -> torch.Tensor:
    """
    Draws bounding boxes on the given image.
    The values of the input image should be uint8 between 0 and 255.
    If fill is True, the resulting Tensor should be saved as a PNG image.

    Args:
        image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
        boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
            the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
            `0 <= ymin < ymax < H`.
        labels (List[str]): List containing the labels of bounding boxes.
        colors (color or list of colors, optional): List containing the colors
            of the boxes or single color for all boxes. The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            By default, random colors are generated for boxes.
        fill (bool): If `True` fills the bounding box with specified color.
        width (int): Width of bounding box.
        font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
            also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
            `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
        font_size (int): The requested font size in points.

    Returns:
        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
    """

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_bounding_boxes)
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Tensor expected, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size(0) not in {1, 3}:
        raise ValueError("Only grayscale and RGB images are supported")
    elif (boxes[:, 0] > boxes[:, 2]).any() or (boxes[:, 1] > boxes[:, 3]).any():
        raise ValueError(
            "Boxes need to be in (xmin, ymin, xmax, ymax) format. Use torchvision.ops.box_convert to convert them"
        )

    num_boxes = boxes.shape[0]

    if num_boxes == 0:
        warnings.warn("boxes doesn't contain any box. No box was drawn")
        return image

    if labels is None:
        labels: Union[List[str], List[None]] = [None] * num_boxes  # type: ignore[no-redef]
    elif len(labels) != num_boxes:
        raise ValueError(
            f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box."
        )

    if colors is None:
        colors = _generate_color_palette(num_boxes)
    elif isinstance(colors, list):
        if len(colors) < num_boxes:
            raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}).")
    else:  # colors specifies a single color for all boxes
        colors = [colors] * num_boxes

    colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors]

    if font is None:
        if font_size is not None:
            warnings.warn("Argument 'font_size' will be ignored since 'font' is not set.")
        txt_font = ImageFont.load_default()
    else:
        txt_font = ImageFont.truetype(font=font, size=font_size or 10)

    # Handle Grayscale images
    if image.size(0) == 1:
        image = torch.tile(image, (3, 1, 1))

    ndarr = image.permute(1, 2, 0).cpu().numpy()
    img_to_draw = Image.fromarray(ndarr)
    img_boxes = boxes.to(torch.int64).tolist()

    if fill:
        draw = ImageDraw.Draw(img_to_draw, "RGBA")
    else:
        draw = ImageDraw.Draw(img_to_draw)

    for bbox, color, label in zip(img_boxes, colors, labels):  # type: ignore[arg-type]
        if fill:
            fill_color = color + (100,)
            draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
        else:
            draw.rectangle(bbox, width=width, outline=color)

        if label is not None:
            margin = width + 1
            draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)

    return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)

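# Illustrative usage sketch (not part of the original file): draw_bounding_boxes
# expects a uint8 image and absolute (xmin, ymin, xmax, ymax) coordinates.
#
#   img = torch.zeros(3, 100, 100, dtype=torch.uint8)
#   boxes = torch.tensor([[10, 10, 50, 50], [60, 20, 90, 80]], dtype=torch.float)
#   out = draw_bounding_boxes(img, boxes, labels=["a", "b"], colors="red", width=2)
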
@torch.no_grad()
def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    alpha: float = 0.8,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
) -> torch.Tensor:
    """
    Draws segmentation masks on the given RGB image.
    The values of the input image should be uint8 between 0 and 255.

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
        masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
        alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
            0 means full transparency, 1 means no transparency.
        colors (color or list of colors, optional): List containing the colors
            of the masks or single color for all masks. The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            By default, random colors are generated for each mask.

    Returns:
        img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
    """

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_segmentation_masks)
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")
    if masks.ndim == 2:
        masks = masks[None, :, :]
    if masks.ndim != 3:
        raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
    if masks.dtype != torch.bool:
        raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
    if masks.shape[-2:] != image.shape[-2:]:
        raise ValueError("The image and the masks must have the same height and width")

    num_masks = masks.size()[0]
    if colors is not None and num_masks > len(colors):
        raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")

    if num_masks == 0:
        warnings.warn("masks doesn't contain any mask. No mask was drawn")
        return image

    if colors is None:
        colors = _generate_color_palette(num_masks)

    if not isinstance(colors, list):
        colors = [colors]
    if not isinstance(colors[0], (tuple, str)):
        raise ValueError("colors must be a tuple or a string, or a list thereof")
    if isinstance(colors[0], tuple) and len(colors[0]) != 3:
        raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")

    out_dtype = torch.uint8

    colors_ = []
    for color in colors:
        if isinstance(color, str):
            color = ImageColor.getrgb(color)
        colors_.append(torch.tensor(color, dtype=out_dtype))

    img_to_draw = image.detach().clone()
    # TODO: There might be a way to vectorize this
    for mask, color in zip(masks, colors_):
        img_to_draw[:, mask] = color[:, None]

    out = image * (1 - alpha) + img_to_draw * alpha
    return out.to(out_dtype)

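# Illustrative usage sketch (not part of the original file): masks are boolean
# per-pixel selections; alpha blends each colored mask over the image.
#
#   img = torch.zeros(3, 4, 4, dtype=torch.uint8)
#   masks = torch.zeros(1, 4, 4, dtype=torch.bool)
#   masks[0, :2, :2] = True
#   out = draw_segmentation_masks(img, masks, alpha=0.5, colors="blue")
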
@torch.no_grad()
def draw_keypoints(
    image: torch.Tensor,
    keypoints: torch.Tensor,
    connectivity: Optional[List[Tuple[int, int]]] = None,
    colors: Optional[Union[str, Tuple[int, int, int]]] = None,
    radius: int = 2,
    width: int = 3,
) -> torch.Tensor:
    """
    Draws keypoints on the given RGB image.
    The values of the input image should be uint8 between 0 and 255.

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
        keypoints (Tensor): Tensor of shape (num_instances, K, 2), holding the K keypoint locations
            for each instance, in the format [x, y].
        connectivity (List[Tuple[int, int]]): A list of tuples, where each tuple contains a pair of
            keypoint indices to be connected.
        colors (str, Tuple): The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
        radius (int): Integer denoting the radius of each keypoint.
        width (int): Integer denoting the width of the lines connecting keypoints.

    Returns:
        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
    """

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_keypoints)
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")

    if keypoints.ndim != 3:
        raise ValueError("keypoints must be of shape (num_instances, K, 2)")

    ndarr = image.permute(1, 2, 0).cpu().numpy()
    img_to_draw = Image.fromarray(ndarr)
    draw = ImageDraw.Draw(img_to_draw)
    img_kpts = keypoints.to(torch.int64).tolist()

    for kpt_id, kpt_inst in enumerate(img_kpts):
        for inst_id, kpt in enumerate(kpt_inst):
            x1 = kpt[0] - radius
            x2 = kpt[0] + radius
            y1 = kpt[1] - radius
            y2 = kpt[1] + radius
            draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0)

        if connectivity:
            for connection in connectivity:
                start_pt_x = kpt_inst[connection[0]][0]
                start_pt_y = kpt_inst[connection[0]][1]

                end_pt_x = kpt_inst[connection[1]][0]
                end_pt_y = kpt_inst[connection[1]][1]

                draw.line(
                    ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)),
                    width=width,
                )

    return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)

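# Illustrative usage sketch (not part of the original file): keypoints are given
# per instance as (x, y) pairs; connectivity optionally joins pairs of indices.
#
#   img = torch.zeros(3, 100, 100, dtype=torch.uint8)
#   kpts = torch.tensor([[[10, 10], [50, 50], [90, 10]]], dtype=torch.float)
#   out = draw_keypoints(img, kpts, connectivity=[(0, 1), (1, 2)], colors="red")
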
# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
@torch.no_grad()
def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
    """
    Converts a flow to an RGB image.

    Args:
        flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.

    Returns:
        img (Tensor): Image Tensor of dtype uint8 where each color corresponds
            to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
    """

    if flow.dtype != torch.float:
        raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")

    orig_shape = flow.shape
    if flow.ndim == 3:
        flow = flow[None]  # Add batch dim

    if flow.ndim != 4 or flow.shape[1] != 2:
        raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")

    max_norm = torch.sum(flow**2, dim=1).sqrt().max()
    epsilon = torch.finfo(flow.dtype).eps
    normalized_flow = flow / (max_norm + epsilon)
    img = _normalized_flow_to_image(normalized_flow)

    if len(orig_shape) == 3:
        img = img[0]  # Remove batch dim
    return img

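# Illustrative usage sketch (not part of the original file): each flow vector is
# mapped to a hue by its angle and to color saturation by its (globally
# normalized) length.
#
#   flow = torch.randn(2, 64, 64)   # (2, H, W), dtype float
#   rgb = flow_to_image(flow)       # -> uint8 tensor of shape (3, 64, 64)
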
@torch.no_grad()
def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
    """
    Converts a batch of normalized flow to an RGB image.

    Args:
        normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
    Returns:
        img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
    """

    N, _, H, W = normalized_flow.shape
    device = normalized_flow.device
    flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
    colorwheel = _make_colorwheel().to(device)  # shape [55x3]
    num_cols = colorwheel.shape[0]
    norm = torch.sum(normalized_flow**2, dim=1).sqrt()
    a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
    fk = (a + 1) / 2 * (num_cols - 1)
    k0 = torch.floor(fk).to(torch.long)
    k1 = k0 + 1
    k1[k1 == num_cols] = 0
    f = fk - k0

    for c in range(colorwheel.shape[1]):
        tmp = colorwheel[:, c]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        col = (1 - f) * col0 + f * col1
        col = 1 - norm * (1 - col)
        flow_image[:, c, :, :] = torch.floor(255 * col)
    return flow_image

def _make_colorwheel() -> torch.Tensor:
    """
    Generates a color wheel for optical flow visualization as presented in:
    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.

    Returns:
        colorwheel (Tensor[55, 3]): Colorwheel Tensor.
    """

    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = torch.zeros((ncols, 3))
    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY)
    col = col + RY
    # YG
    colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG)
    colorwheel[col : col + YG, 1] = 255
    col = col + YG
    # GC
    colorwheel[col : col + GC, 1] = 255
    colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC)
    col = col + GC
    # CB
    colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB)
    colorwheel[col : col + CB, 2] = 255
    col = col + CB
    # BM
    colorwheel[col : col + BM, 2] = 255
    colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM)
    col = col + BM
    # MR
    colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR)
    colorwheel[col : col + MR, 0] = 255
    return colorwheel

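# Sanity note (not part of the original file): the six sectors sum to
# RY + YG + GC + CB + BM + MR = 15 + 6 + 4 + 11 + 13 + 6 = 55 rows, matching the
# Tensor[55, 3] shape documented above and the "[55x3]" comment in
# _normalized_flow_to_image.
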
def _generate_color_palette(num_objects: int):
    palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1])
    return [tuple((i * palette) % 255) for i in range(num_objects)]

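# Illustrative note (not part of the original file): the palette is generated
# deterministically from multiples of a fixed base tuple, e.g. i = 1 yields
#   ((2**25 - 1) % 255, (2**15 - 1) % 255, (2**21 - 1) % 255) = (1, 127, 31).
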
def _log_api_usage_once(obj: Any) -> None:
    """
    Logs API usage (module and name) within an organization.
    In a large ecosystem, it's often useful to track the PyTorch and
    TorchVision API usage. This API provides similar functionality to the
    logging module in the Python stdlib. It can be used for debugging purposes
    to log which methods are used, and by default it is inactive unless the user
    manually subscribes a logger via the `SetAPIUsageLogger method <https://github.com/pytorch/pytorch/blob/eb3b9fe719b21fae13c7a7cf3253f970290a573e/c10/util/Logging.cpp#L114>`_.
    Please note it is triggered only once for the same API call within a process.
    It does not collect any data from open-source users since it is a no-op by default.
    For more information, please refer to
    * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging;
    * Logging policy: https://github.com/pytorch/vision/issues/5052;

    Args:
        obj (class instance or method): an object to extract info from.
    """
    pass  # intentionally a no-op in this port

def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
    """
    Make an n-tuple from input x. If x is an iterable, we just convert it to a tuple.
    Otherwise, we make a tuple of length n, all with the value of x.
    reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8

    Args:
        x (Any): input value
        n (int): length of the resulting tuple
    """
    if isinstance(x, collections.abc.Iterable):
        return tuple(x)
    return tuple(repeat(x, n))

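# Illustrative usage sketch (not part of the original file):
#   _make_ntuple(3, 2)       # -> (3, 3)
#   _make_ntuple((1, 2), 2)  # -> (1, 2)
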
@@ -457,7 +457,8 @@ def setup_cutt():

def install_cutlass(root_folder):
    # Modified from: https://github.com/ap-hynninen/cutlass
    # url = "https://cloud.tsinghua.edu.cn/f/171e49e5825549548bc4/?dl=1"
    url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/cutlass.zip"
    # url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/cutlass.zip"
    url = "https://cloud.tsinghua.edu.cn/f/171e49e5825549548bc4/?dl=1"

    filename = "cutlass.zip"
    fullname = os.path.join(root_folder, filename)

@@ -611,6 +612,26 @@ def setup_nccl():
    nccl_ops = nccl.ops
    LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))

def setup_hccl():
    global hccl_ops

    hccl_src_dir = os.path.join(jittor_path, "extern", "acl", "hccl")
    hccl_src_files = []
    for r, _, f in os.walk(hccl_src_dir):
        for fname in f:
            hccl_src_files.append(os.path.join(r, fname))

    hccl_include_path = os.path.join(os.environ.get("ASCEND_TOOLKIT_HOME"), "aarch64-linux/include/hccl")
    hccl_lib_name = os.path.join(os.environ.get("ASCEND_TOOLKIT_HOME"), "aarch64-linux/lib64/libhccl.so")
    ctypes.CDLL(hccl_lib_name, dlopen_flags)

    hccl = compile_custom_ops(hccl_src_files,
        extra_flags=f" -I\"{hccl_include_path}\" {mpi_compile_flags} ",
        return_module=True, dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW,
        gen_name_="jittor_hccl_core")
    hccl_ops = hccl.ops
    LOG.vv("Get hccl_ops: "+str(dir(hccl_ops)))

def manual_link(flags):
    lib_dirs = []
    libs = []

@@ -708,8 +729,14 @@ cudnn = cublas = curand = cufft = cusparse = None
setup_mpi()
rank = mpi.world_rank() if in_mpi else 0
world_size = mpi.world_size() if in_mpi else 1
# if has_acl:
#     setup_hccl()
# elif has_cuda:
#     setup_nccl()
#     setup_cutt()
#     setup_cutlass()

setup_nccl()

setup_cutt()
setup_cutlass()

@@ -1186,9 +1186,22 @@ make_cache_dir(os.path.join(cache_path, "tmp"))
ck_path = os.path.join(cache_path, "checkpoints")
make_cache_dir(ck_path)

ascend_toolkit_home = os.getenv('ASCEND_TOOLKIT_HOME')

# build cache_compile
cc_flags += f" -I\"{os.path.join(jittor_path, 'src')}\" "
cc_flags += f" -I\"{os.path.join(jittor_path, 'extern')}\" "
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include')}\" "
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/acl')}\" "
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/aclnn')}\" "
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/aclnnop')}\" "
cc_flags += f" -L\"{os.path.join(ascend_toolkit_home, 'lib64')}\" "
cc_flags += " -llibascendcl "
cc_flags += " -llibnnopbase "
cc_flags += " -llibopapi "
cc_flags += py_include

check_cache_compile()
LOG.v(f"Get cache_compile: {jit_utils.cc}")

File diff suppressed because it is too large
@@ -1,6 +1,6 @@
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

@@ -11,23 +11,27 @@ using std::unordered_map;

typedef int aclError;

static inline unordered_map<aclError, string> gen_map(string s)
{
    unordered_map<aclError, string> smap;
    for (int i = 0; i < s.size(); i++)
    {
        if (s[i] == ';')
        {
            int j = s.rfind(" ", i);
            int code = std::stoi(s.substr(j + 1, i - j - 1));
            int k = s.rfind(" ", j - 1);
            int l = s.rfind(" ACL_", k - 1);
            smap[code] = s.substr(l + 1, k - l - 1);
        }
    }
    return smap;
}

string acl_error_to_string(aclError error)
{

    static unordered_map<aclError, string> acl_error_map = gen_map(R"(
// from acl_base.h
static const int ACL_ERROR_INVALID_PARAM = 100000;
static const int ACL_ERROR_UNINITIALIZE = 100001;

@@ -1,6 +1,6 @@
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

@@ -10,267 +10,311 @@
#include "utils/str_utils.h"
#include <chrono>
#include <thread>
#include "aclnn/aclnn.h"

namespace jittor
{

    uint64_t acl_jittor_tid;
    int acl_jittor_thread_running = 0;
    aclrtStream aclstream;
    void *workspaceAddr = nullptr;
    uint64_t nowWorkSpaceSize = 0;

#define CHECK_ACL(x) ASSERTop(x, ==, 0)

    void mallocWorkSpace(uint64_t size)
    {
        uint64_t alloc_size = size + 32;
        alloc_size = ((alloc_size - 1) / 32 + 1) * 32;
        if (alloc_size > nowWorkSpaceSize)
        {
            aclrtFree(workspaceAddr);
            nowWorkSpaceSize = alloc_size;
            auto ret = aclrtMalloc(&workspaceAddr, nowWorkSpaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
            CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return);
        }
    }

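    // Illustrative note (not from the original source): the workspace only ever
    // grows, and requests are padded by 32 bytes then rounded up to a 32-byte
    // multiple. For example, a request of 100 bytes becomes 100 + 32 = 132,
    // which rounds to ((132 - 1) / 32 + 1) * 32 = 160 bytes.
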
    static void *acl_jittor_process_callback(void *)
    {
        acl_jittor_thread_running = 1;

        while (acl_jittor_thread_running)
        {
            // LOGir << "acl_jittor_process_callback";
            auto ret = aclrtProcessReport(1000);
            if (ret)
            {
                if (acl_jittor_thread_running && ret != ACL_ERROR_RT_REPORT_TIMEOUT && ret != ACL_ERROR_RT_THREAD_SUBSCRIBE)
                    LOGir << "aclrtProcessReport:" << ret << acl_error_to_string(ret);
                break;
            }
        }
        acl_jittor_thread_running = 0;
        return (void *)0;
    }

    // void aaa(void*) {
    //     LOGir << "haha";
    // }

    struct acl_jittor_initer
    {
        int32_t deviceId;
        acl_jittor_initer()
        {
            CHECK_ACL(aclInit(nullptr));
            uint device_count = 0;
            deviceId = 0;
            // query the number of available devices
            CHECK_ACL(aclrtGetDeviceCount(&device_count));
            LOGi << "Found ACL device number:" << device_count;
            CHECK_ACL(aclrtSetDevice(deviceId));
            CHECK_ACL(aclrtCreateStream(&aclstream));
            // pthread_create(&acl_jittor_tid, nullptr, acl_jittor_process_callback, 0);
        }

        ~acl_jittor_initer()
        {
            acl_jittor_thread_running = 0;
            // CHECK_ACL(aclrtUnSubscribeReport(acl_jittor_tid, 0));
            aclrtDestroyStream(aclstream);
            aclrtResetDevice(deviceId);
            CHECK_ACL(aclFinalize());
            if (nowWorkSpaceSize > 0)
            {
                aclrtFree(workspaceAddr);
            }
        }

    } _acl_jittor_initer;

    string process_acl(const string &src, const string &name, const map<string, string> &kargs)
    {
        if (endswith(name, "_jittor.cc"))
            return src;
        // static vector<string> dont_compile = {"fp16_emu.cc"};
        // for (auto& s : dont_compile)
        //     if (endswith(name, s))
        //         return " ";
        static unordered_set<string> cuda_headers = {
            "cuda_runtime", "cudnn", "driver_types",
            "cuda_fp16", "cuda_runtime_api", "fp16_emu",
            "cudnn_rnn_descriptor", "cublas_v2", "cublas_wrapper",
            "curand", "curand_wrapper", "cufft", "cufftXt",
            "CudaUtils", "cutt", "cudnn_wrapper", "cuda_bf16"};
        static unordered_set<string> fake_class = {
            "cudnnHandle_t", "cudnnConvolutionBwdFilterAlgo_t",
            "cudnnConvolutionBwdDataAlgo_t", "cudnnConvolutionFwdAlgo_t",
            "cufftHandle"};
        try
        {
            auto tokens = token_split(src);
            int edit = 0;
            for (int i = 0; i < tokens.size(); i++)
            {
                auto &token = tokens[i];
                if (cuda_headers.count(token))
                    token = "acl_jittor", edit++;
                else if (fake_class.count(token))
                    token = "int", edit++;
                else if (token == "CUDA")
                    token = "ACL", edit++;
                else if (startswith(token, "cuda"))
                {
                    if (token.size() >= 5 && token[4] >= 'A' && token[4] <= 'Z')
                    {
                        if (token == "cudaGetDeviceCount")
                        {
                            token_replace(tokens, i, "($1);", "((uint*)$1);");
                        }
                        else if (token == "cudaLaunchHostFunc")
                        {
                            // ACL_CALLBACK_BLOCK for 310
                            token_replace(tokens, i, "LaunchHostFunc($1,$2,$3)",
                                          "LaunchCallback($2,$3,ACL_CALLBACK_NO_BLOCK,$1)");
                        }
                        else if (token == "cudaMemcpy")
                            token_replace(tokens, i, "cudaMemcpy($1,$2,$3,",
                                          "aclrtMemcpy($1,$3,$2,$3,");
                        else if (token == "cudaMemcpyAsync")
                            token_replace(tokens, i, "cudaMemcpyAsync($1,$2,$3,",
                                          "aclrtMemcpyAsync($1,$3,$2,$3,");
                        else if (token == "cudaMemcpyDeviceToHost")
                            token = "ACL_MEMCPY_DEVICE_TO_HOST";
                        else if (token == "cudaMemcpyDefault")
                            token = "ACL_MEMCPY_HOST_TO_DEVICE";
                        else if (token == "cudaMemcpyHostToDevice")
                            token = "ACL_MEMCPY_HOST_TO_DEVICE";
                        else if (token == "cudaMemcpyDeviceToDevice")
                            token = "ACL_MEMCPY_DEVICE_TO_DEVICE";
                        else if (token == "cudaMallocManaged" || token == "cudaMalloc")
                        {
                            // unified address not supported
                            token = "aclrtMalloc";
                            token_replace(tokens, i, "($1,$2)",
                                          "($1,$2,ACL_MEM_MALLOC_HUGE_FIRST)");
                        }
                        else if (token == "cudaMemGetInfo")
                            token_replace(tokens, i, "cudaMemGetInfo($1,$2)",
                                          "aclrtGetMemInfo(ACL_DDR_MEM,$1,$2)");
                        else if (token == "cudaGetLastError")
                            token_replace(tokens, i, "cudaGetLastError()", "0");
                        else if (token == "cudaStreamCreateWithFlags")
                            token_replace(tokens, i - 1,
                                          "(cudaStreamCreateWithFlags($1,$2));",
                                          "(aclrtCreateStream($1)); checkAclErrors(aclrtSubscribeReport(acl_jittor_tid,*$1));");
                        else if (token == "cudaEventCreate")
                            token_replace(tokens, i,
                                          "cudaEventCreate($1,$2)",
                                          "aclrtCreateEvent($1)");
                        else if (token == "cudaDeviceSynchronize")
                            token = "aclrtSynchronizeDevice";
                        else if (token == "cudaStreamDestroy")
                            token_replace(tokens, i, "cudaStreamDestroy($1)",
                                          "(aclrtUnSubscribeReport(acl_jittor_tid,$1), aclrtDestroyStream($1))");
                        else if (token == "cudaEventDestroy")
                            token = "aclrtDestroyEvent";
                        else if (token == "cudaEventRecord")
                            token = "aclrtRecordEvent";
                        else if (token == "cudaStreamWaitEvent")
                            token_replace(tokens, i,
                                          "cudaStreamWaitEvent($1,$2,$3)",
                                          "aclrtStreamWaitEvent($1,$2)");

                        if (token.size() && token[0] == 'c')
                            token = "aclrt" + token.substr(4);
                        if (endswith(token, "_t"))
                            token = token.substr(0, token.size() - 2);
                        edit++;
                    }
                }
                else if (token == "_cudaGetErrorEnum")
                {
                    token_replace(tokens, i, "_cudaGetErrorEnum($1)", "(acl_error_to_string($1))");
                    edit++;
                }
                else if (token == "checkCudaErrors")
                    token = "checkAclErrors";
                else if (token == "JPU")
                {
                    edit++;
                    string new_code;
                    if (tokens[i + 2] == "op_compiler")
                        token_replace(tokens, i,
                                      "JPU(op_compiler($1,$2,$3))",
                                      "acl_jittor_op_compiler($1,$2,$3)");
                    else if (tokens[i + 2] == "header")
                        new_code = "#include \"acl_jittor.h\"";
                    if (new_code.size())
                        token_replace(tokens, i, "JPU($1)", new_code);
                }
                else if (token == "use_cuda_managed_allocator" && tokens[i + 1][0] == ',')
                {
                    tokens[i + 2] = "0"; // disable unified address
                }
            }
            if (!edit)
                return src;
            string new_src = join(tokens, "");
            // if (name == "executor.cc") {
            //     new_src = string("#include <Python.h>\n#include <pystate.h>\n#include <common.h>\n")+
            //         "namespace jittor { void acl_op_exec(Op*); }\n" +
            //         replace(new_src, "op->do_run_after_prepare(jkl);",
            //         R"({
            //             acl_op_exec(op);
            //         })");
            // }
            if (name == "profiler.cc")
            {
                new_src = token_replace_all(new_src, ".cc", ".tikcc");
            }
            // LOGir << name << (name == "pass_manager.cc");
            if (name == "pass_manager.cc")
            {
                LOGir << "replace" << name;
                new_src = token_replace_all(new_src, "run_pass<FloatAtomicFixPass>();", "WTF");
            }
            // ????????
            return new_src;
        }
        catch (const std::exception &e)
        {
            LOGe << "process acl error:" << e.what();
            LOGe << "name:" << name;
            throw;
        }
    }

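    // Illustrative example of the rewrite above (not from the source tree):
    // an input fragment such as
    //     checkCudaErrors(cudaMemcpy(dst, src, n, cudaMemcpyHostToDevice));
    // would be token-rewritten into roughly
    //     checkAclErrors(aclrtMemcpy(dst, n, src, n, ACL_MEMCPY_HOST_TO_DEVICE));
    // i.e. ACL's memcpy takes an explicit destination size, and the direction
    // enum is mapped onto its ACL counterpart.
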
    void acl_jittor_op_compiler(string &filename, string &src, bool is_acl, string &extra_flags)
    {
        if (!is_acl)
            return;
        string new_src = process_acl(src, "", {});
        new_src = replace(new_src, R"(#include "misc/cuda_atomic.h")", "");
        new_src = replace(new_src, R"(#include "misc/cuda_limits.h")", "");
        new_src = replace(new_src, "__global__", "__ai_device_entry__");
        new_src = token_replace_all(new_src, "__launch_bounds__($1)", "");
        new_src = token_replace_all(new_src, "int thread_num = $1;", "int thread_num = 1;");
        new_src = token_replace_all(new_src, "tn0=std::max(tn0, $1);", "");
        new_src = token_replace_all(new_src, "<<<$1>>>", "<<<1,0>>>");
        new_src = token_replace_all(new_src, "int thread_id = $1;", "int thread_id = 1;");
        // for inc error
        new_src = token_replace_all(new_src, "for ($1+=$2)", "for ($1++)");
        // bit op error
        new_src = token_replace_all(new_src, "int tnum$1;", "");
        new_src = token_replace_all(new_src, "int p1$1;", "");
        new_src = token_replace_all(new_src, "int p2$1;", "");
        new_src = token_replace_all(new_src, "int tn$1=$2;", "int tn$1=0;");
        new_src = token_replace_all(new_src, "int tid$1=$2;", "int tid$1=0;");
        src = new_src;

        new_src = token_replace_all(new_src, "atomicAdd(&$1,$2);", "$1=$1+$2;");
        // new_src = token_replace_all(new_src, "bool", "int8");
        new_src = token_replace_all(new_src, "::numeric_min<float32>()", "-1e30");
        new_src = token_replace_all(new_src, "::numeric_max<float32>()", "1e30");
        // TODO: support max
        unordered_map<string, string> opmap = {
            // {"::max","tikcc::scalar_max"},
            {"::sqrtf", "tikcc::scalar_sqrt"}};
        auto ss = split(new_src, ";");
        for (auto &s : ss)
        {
            if (s.find("?") != string::npos)
            {
                s = token_replace_all(s + ";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;");
            }
            if (s.find("::max") != string::npos)
            {
                if (s.find("auto") == string::npos)
                {
                    s = token_replace_all(s + ";", " $1=$4::max($2,$3);", " $1=$2;if ($2 < $3) $1=$3;");
                }
                else
                {
                    s = token_replace_all(s + ";", "auto $1=$4::max($2,$3);", "auto $1=$2;if ($2 < $3) $1=$3;");
                }
            }
            for (auto &kv : opmap)
            {
                if (s.find(kv.first) != string::npos)
                {
                    if (s.find("auto") == string::npos)
                    {
                        // $1 = op($2) --> op($1, $2)
                        s = token_replace_all(s + ";", " $1= " + kv.first + "($2);", kv.second + "($1, $2);");
                    }
                    else
                    {
                        // auto $1 = op($2) --> float32 $1; op($1, $2);
                        s = token_replace_all(s + ";", "auto $1= " + kv.first + "($2);", "float32 $1; " + kv.second + "($1, $2);");
                    }
                }
            }
            // s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;");
            // if (s.find("::max") != string::npos) {
            //     s = token_replace_all(s+";", " $1= ::max($2);", "tikcc::scalar_max($1, $2);");
            // }
        }
        new_src = join(ss, ";");
        src = new_src;
    }

}

@@ -1,6 +1,6 @@
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

@@ -10,11 +10,690 @@

std::string acl_error_to_string(aclError error);

namespace jittor
{

    EXTERN_LIB uint64_t acl_jittor_tid;
    EXTERN_LIB aclrtStream aclstream;
    EXTERN_LIB void *workspaceAddr;

    void acl_jittor_op_compiler(string &filename, string &src, bool is_acl, string &extra_flags);
    void mallocWorkSpace(uint64_t size);

    struct AclOpFunctions
    {
        // for Unary and Nonzero
        std::function<aclnnStatus(aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncUnaryNonzero;
        // for Cast
        std::function<aclnnStatus(aclTensor *, aclDataType, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncCast;
        // for Binary
        std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncBinary;
        // for Add and Sub
        std::function<aclnnStatus(aclTensor *, aclTensor *, aclScalar *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncAdd;
        // for Expand, permute, flip
        std::function<aclnnStatus(aclTensor *, aclIntArray *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncExpand;
        // for bmm and matmul
        std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, int8_t, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncMatmul;
        // for conv
        std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclIntArray *, int64_t, aclTensor *, int8_t, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncConv;
        // for reducesum, mean
        std::function<aclnnStatus(aclTensor *, aclIntArray *, bool, aclDataType, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncReduceSum;
        // for amax and amin
        std::function<aclnnStatus(aclTensor *, aclIntArray *, bool, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncAmax;
        // for conv backward
        std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclIntArray *, int, aclBoolArray *, int8_t, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncConvBackward;
        // for proddim
        std::function<aclnnStatus(aclTensor *, float, float, float, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncProdDim;
        // for select, where
        std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncSelect;
        // for random_uniform and random_normal
        std::function<aclnnStatus(aclTensor *, int64_t, int64_t, int64_t, int64_t, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncRandom;
        // for maxpool
        std::function<aclnnStatus(aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncMaxPool;
        // for maxpool backward
        std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncMaxPoolBackward;
        // for avgpool
        std::function<aclnnStatus(aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, bool, bool, int64_t, int8_t, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncAvgPool;
        // for avgpool backward
        std::function<aclnnStatus(aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, bool, bool, int64_t, int8_t, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncAvgPoolBackward;
        // for concat
        std::function<aclnnStatus(aclTensorList *, uint64_t, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncConcat;
        // for gather
        std::function<aclnnStatus(aclTensor *, uint64_t, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncGather;
        // for cumsum
        std::function<aclnnStatus(aclTensor *, uint64_t, aclDataType, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncCumsum;
        // for scatter
        std::function<aclnnStatus(aclTensor *, uint64_t, aclTensor *, aclTensor *, uint64_t, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncScatter;
        // for index
        std::function<aclnnStatus(aclTensor *, aclTensorList *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncIndex;
        // for stridesliceassignv2
        std::function<aclnnStatus(aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncStridedSliceAssignV2;
        // for slicev2
        std::function<aclnnStatus(aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncSliceV2;
        // for indexputimpl
        std::function<aclnnStatus(aclTensor *, aclTensorList *, aclTensor *, bool, bool, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncIndexPutImpl;
        // for range
        std::function<aclnnStatus(aclScalar *, aclScalar *, aclScalar *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncRange;
        // for leaky_relu
        std::function<aclnnStatus(aclTensor *, aclScalar *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncLeakyRelu;
        // for leaky_relu backward
        std::function<aclnnStatus(aclTensor *, aclTensor *, aclScalar *, bool, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncLeakyReluBackward;
        // for dropout
        std::function<aclnnStatus(aclTensor *, double, bool, int64_t, int64_t, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncDropout;
        // for dropout backward
        std::function<aclnnStatus(aclTensor *, aclTensor *, double, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncDropoutBackward;
        // for split with size
        std::function<aclnnStatus(aclTensor *, aclIntArray *, int64_t, aclTensorList *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncSplitWithSize;

        // for silu
        // std::function<aclnnStatus(aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncSilu;

        // for silu backward
        // std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncSiluBackward;

        // for sigmoid
        // std::function<aclnnStatus(aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncSigmoid;

        // for sigmoid backward
        // std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncSigmoidBackward;

        // for embedding
        // std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncEmbedding;

        // for embedding backward
        std::function<aclnnStatus(aclTensor *, aclTensor *, uint64_t, uint64_t, bool, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncEmbeddingBackward;

        // for InplaceMaskedScatter MaskedSelect
        // std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncInplaceMaskedScatter;
        std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, aclrtStream)> executeFunc;

        // for flashattention
        std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *,
                                  aclIntArray *, aclIntArray *, aclIntArray *, double, double, int64_t, int64_t, int64_t, char *, int64_t, int64_t, int64_t,
                                  aclTensor *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)>
            getWorkspaceSizeFuncFalshAttention;

        // for flashattention backward
        std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *,
                                  aclIntArray *, aclIntArray *, aclIntArray *, double, double, int64_t, int64_t, int64_t, char *, int64_t, int64_t, int64_t,
                                  aclTensor *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)>
            getWorkspaceSizeFuncFalshAttentionBackward;

        // for batchnorm
        std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, bool, double, double, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncBatchNorm;

        // for batchnorm backward
        std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, bool, double, aclBoolArray *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncBatchNormBackward;

        // for layernorm
        std::function<aclnnStatus(aclTensor *, aclIntArray *, aclTensor *, aclTensor *, double, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncLayerNorm;

        // for ROPE
        std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, int64_t, uint64_t *, aclOpExecutor **)>
            getWorkspaceSizeFuncRotaryPosEmb;

        // add a default constructor
        AclOpFunctions() = default;

        // for Unary and Nonzero
        AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
                       std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, aclrtStream)> execf)
            : getWorkspaceSizeFuncUnaryNonzero(gwsf), executeFunc(execf) {}

        // for Cast
        AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclDataType, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
                       std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, aclrtStream)> execf)
            : getWorkspaceSizeFuncCast(gwsf), executeFunc(execf) {}

        // for Binary
        AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
                       std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
            : getWorkspaceSizeFuncBinary(gwsf), executeFunc(execf) {}

        // for Add and Sub
        AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclScalar *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
                       std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
            : getWorkspaceSizeFuncAdd(gwsf), executeFunc(execf) {}

        // for Expand, flip
        AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
                       std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
            : getWorkspaceSizeFuncExpand(gwsf), executeFunc(execf) {}

        // for Matmul
        AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, int8_t, uint64_t *, aclOpExecutor **)> gwsf,
                       std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
            : getWorkspaceSizeFuncMatmul(gwsf), executeFunc(execf) {}

        // for conv
        AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclIntArray *, int64_t, aclTensor *, int8_t, uint64_t *, aclOpExecutor **)> gwsf,
                       std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
            : getWorkspaceSizeFuncConv(gwsf), executeFunc(execf) {}

        // for reducesum, mean
        AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, bool, aclDataType, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
                       std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
            : getWorkspaceSizeFuncReduceSum(gwsf), executeFunc(execf) {}

        // for amax amin
        AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, bool, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
                       std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
            : getWorkspaceSizeFuncAmax(gwsf), executeFunc(execf) {}

        // for conv backward
        AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclIntArray *, int, aclBoolArray *, int8_t, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
                       std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
            : getWorkspaceSizeFuncConvBackward(gwsf), executeFunc(execf) {}

        // for proddim
        AclOpFunctions(std::function<aclnnStatus(const aclTensor *, float, float, float, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
                       std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
            : getWorkspaceSizeFuncProdDim(gwsf), executeFunc(execf) {}

        // for select, where
        AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
                       std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
            : getWorkspaceSizeFuncSelect(gwsf), executeFunc(execf) {}

        // for random_normal
        AclOpFunctions(std::function<aclnnStatus(aclTensor *, int64_t, int64_t, int64_t, int64_t, uint64_t *, aclOpExecutor **)> gwsf,
                       std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
            : getWorkspaceSizeFuncRandom(gwsf), executeFunc(execf) {}

        // for maxpool
        AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
                       std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
            : getWorkspaceSizeFuncMaxPool(gwsf), executeFunc(execf) {}

        // for maxpool backward
        AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncMaxPoolBackward(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for avgpool
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, bool, bool, int64_t, int8_t, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncAvgPool(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for avgpool backward
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, bool, bool, int64_t, int8_t, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncAvgPoolBackward(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for concat
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensorList *, int64_t, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncConcat(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for gather
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, int64_t, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncGather(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for cumsum
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, int64_t, aclDataType, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncCumsum(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for scatter
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, uint64_t, aclTensor *, aclTensor *, uint64_t, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncScatter(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for index
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensorList *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncIndex(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for stridesliceassignv2
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncStridedSliceAssignV2(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for slicev2
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncSliceV2(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for indexputimpl
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensorList *, aclTensor *, bool, bool, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncIndexPutImpl(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for range
|
||||
AclOpFunctions(std::function<aclnnStatus(aclScalar *, aclScalar *, aclScalar *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncRange(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for leaky_relu
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclScalar *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncLeakyRelu(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for leaky_relu backward
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclScalar *, bool, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncLeakyReluBackward(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for dropout
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, double, bool, int64_t, int64_t, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncDropout(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for dropout backward
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, double, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncDropoutBackward(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for embedding backward
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, uint64_t, uint64_t, bool, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncEmbeddingBackward(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for split with size
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, int64_t, aclTensorList *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncSplitWithSize(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for flash attention
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *,
|
||||
aclIntArray *, aclIntArray *, aclIntArray *, double, double, int64_t, int64_t, int64_t, char *, int64_t, int64_t, int64_t,
|
||||
aclTensor *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)>
|
||||
gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncFalshAttention(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for flash attention backward
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *,
|
||||
aclIntArray *, aclIntArray *, aclIntArray *, double, double, int64_t, int64_t, int64_t, char *, int64_t, int64_t, int64_t,
|
||||
aclTensor *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)>
|
||||
gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncFalshAttentionBackward(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for batchnorm
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, bool, double, double, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)>
|
||||
gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncBatchNorm(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for batchnorm backward
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, bool, double, aclBoolArray *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)>
|
||||
gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncBatchNormBackward(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for layernorm
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, aclTensor *, aclTensor *, double, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)>
|
||||
gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncLayerNorm(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for ROPE
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, const aclTensor *, const aclTensor *, int64_t, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncRotaryPosEmb(gwsf), executeFunc(execf) {}
|
||||
};
|
||||
|
||||

static std::unordered_map<std::string, AclOpFunctions> aclOpFuncMap = {
    {"Abs", AclOpFunctions(aclnnAbsGetWorkspaceSize, aclnnAbs)},
    {"Exp", AclOpFunctions(aclnnExpGetWorkspaceSize, aclnnExp)},
    {"Log", AclOpFunctions(aclnnLogGetWorkspaceSize, aclnnLog)},
    {"Sqrt", AclOpFunctions(aclnnSqrtGetWorkspaceSize, aclnnSqrt)},
    {"Ceil", AclOpFunctions(aclnnCeilGetWorkspaceSize, aclnnCeil)},
    {"Floor", AclOpFunctions(aclnnFloorGetWorkspaceSize, aclnnFloor)},
    {"Round", AclOpFunctions(aclnnRoundGetWorkspaceSize, aclnnRound)},
    {"Sin", AclOpFunctions(aclnnSinGetWorkspaceSize, aclnnSin)},
    {"Cos", AclOpFunctions(aclnnCosGetWorkspaceSize, aclnnCos)},
    {"Tan", AclOpFunctions(aclnnTanGetWorkspaceSize, aclnnTan)},
    {"Asin", AclOpFunctions(aclnnAsinGetWorkspaceSize, aclnnAsin)},
    {"Acos", AclOpFunctions(aclnnAcosGetWorkspaceSize, aclnnAcos)},
    {"Atan", AclOpFunctions(aclnnAtanGetWorkspaceSize, aclnnAtan)},
    {"Sinh", AclOpFunctions(aclnnSinhGetWorkspaceSize, aclnnSinh)},
    {"Cosh", AclOpFunctions(aclnnCoshGetWorkspaceSize, aclnnCosh)},
    {"Tanh", AclOpFunctions(aclnnTanhGetWorkspaceSize, aclnnTanh)},
    {"Asinh", AclOpFunctions(aclnnAsinhGetWorkspaceSize, aclnnAsinh)},
    {"Acosh", AclOpFunctions(aclnnAcoshGetWorkspaceSize, aclnnAcosh)},
    {"Atanh", AclOpFunctions(aclnnAtanhGetWorkspaceSize, aclnnAtanh)},
    {"Sigmoid", AclOpFunctions(aclnnSigmoidGetWorkspaceSize, aclnnSigmoid)},
    {"Erf", AclOpFunctions(aclnnErfGetWorkspaceSize, aclnnErf)},
    {"Erfinv", AclOpFunctions(aclnnErfinvGetWorkspaceSize, aclnnErfinv)},
    {"LogicalNot", AclOpFunctions(aclnnLogicalNotGetWorkspaceSize, aclnnLogicalNot)},
    {"BitwiseNot", AclOpFunctions(aclnnBitwiseNotGetWorkspaceSize, aclnnBitwiseNot)},
    {"Neg", AclOpFunctions(aclnnNegGetWorkspaceSize, aclnnNeg)},
    {"Cast", AclOpFunctions(aclnnCastGetWorkspaceSize, aclnnCast)},
    {"Maximum", AclOpFunctions(aclnnMaximumGetWorkspaceSize, aclnnMaximum)},
    {"Minimum", AclOpFunctions(aclnnMinimumGetWorkspaceSize, aclnnMinimum)},
    {"Add", AclOpFunctions(aclnnAddGetWorkspaceSize, aclnnAdd)},
    {"Sub", AclOpFunctions(aclnnSubGetWorkspaceSize, aclnnSub)},
    {"Mul", AclOpFunctions(aclnnMulGetWorkspaceSize, aclnnMul)},
    {"RealDiv", AclOpFunctions(aclnnDivGetWorkspaceSize, aclnnDiv)},
    {"FloorDiv", AclOpFunctions(aclnnFloorDivideGetWorkspaceSize, aclnnFloorDivide)},
    {"LessEqual", AclOpFunctions(aclnnLeTensorGetWorkspaceSize, aclnnLeTensor)},
    {"Less", AclOpFunctions(aclnnLtTensorGetWorkspaceSize, aclnnLtTensor)},
    {"GreaterEqual", AclOpFunctions(aclnnGeTensorGetWorkspaceSize, aclnnGeTensor)},
    {"Greater", AclOpFunctions(aclnnGtTensorGetWorkspaceSize, aclnnGtTensor)},
    {"Equal", AclOpFunctions(aclnnEqTensorGetWorkspaceSize, aclnnEqTensor)},
    {"NotEqual", AclOpFunctions(aclnnNeTensorGetWorkspaceSize, aclnnNeTensor)},
    {"LogicalAnd", AclOpFunctions(aclnnLogicalAndGetWorkspaceSize, aclnnLogicalAnd)},
    {"LogicalOr", AclOpFunctions(aclnnLogicalOrGetWorkspaceSize, aclnnLogicalOr)},
    {"LogicalXor", AclOpFunctions(aclnnLogicalXorGetWorkspaceSize, aclnnLogicalXor)},
    {"BitwiseAnd", AclOpFunctions(aclnnBitwiseAndTensorGetWorkspaceSize, aclnnBitwiseAndTensor)},
    {"BitwiseOr", AclOpFunctions(aclnnBitwiseOrTensorGetWorkspaceSize, aclnnBitwiseOrTensor)},
    {"BitwiseXor", AclOpFunctions(aclnnBitwiseXorTensorGetWorkspaceSize, aclnnBitwiseXorTensor)},
    {"Pow", AclOpFunctions(aclnnPowTensorTensorGetWorkspaceSize, aclnnPowTensorTensor)},
    {"Expand", AclOpFunctions(aclnnExpandGetWorkspaceSize, aclnnExpand)},
    {"MatMul", AclOpFunctions(aclnnMatmulGetWorkspaceSize, aclnnMatmul)},
    {"BatchMatMul", AclOpFunctions(aclnnBatchMatMulGetWorkspaceSize, aclnnBatchMatMul)},
    {"ReduceMax", AclOpFunctions(aclnnAmaxGetWorkspaceSize, aclnnAmax)},
    {"ReduceMin", AclOpFunctions(aclnnAminGetWorkspaceSize, aclnnAmin)},
    {"ReduceSum", AclOpFunctions(aclnnReduceSumGetWorkspaceSize, aclnnReduceSum)},
    {"Triu", AclOpFunctions(aclnnTriuGetWorkspaceSize, aclnnTriu)},
    {"Conv2d", AclOpFunctions(aclnnConvolutionGetWorkspaceSize, aclnnConvolution)},
    {"Conv2dBackward", AclOpFunctions(aclnnConvolutionBackwardGetWorkspaceSize, aclnnConvolutionBackward)},
    {"ReduceMean", AclOpFunctions(aclnnMeanGetWorkspaceSize, aclnnMean)},
    // {"ReduceProd", AclOpFunctions(aclnnProdDimGetWorkspaceSize, aclnnProdDim)},
    {"Select", AclOpFunctions(aclnnSWhereGetWorkspaceSize, aclnnSWhere)},
    {"RandomUniform", AclOpFunctions(aclnnInplaceUniformGetWorkspaceSize, aclnnInplaceUniform)},
    {"RandomNormal", AclOpFunctions(aclnnInplaceNormalGetWorkspaceSize, aclnnInplaceNormal)},
    {"Transpose", AclOpFunctions(aclnnPermuteGetWorkspaceSize, aclnnPermute)},
    {"Maxpool", AclOpFunctions(aclnnMaxPool2dWithIndicesGetWorkspaceSize, aclnnMaxPool2dWithIndices)},
    {"MaxpoolBackward", AclOpFunctions(aclnnMaxPool2dWithIndicesBackwardGetWorkspaceSize, aclnnMaxPool2dWithIndicesBackward)},
    {"Avgpool", AclOpFunctions(aclnnAvgPool2dGetWorkspaceSize, aclnnAvgPool2d)},
    {"AvgpoolBackward", AclOpFunctions(aclnnAvgPool2dBackwardGetWorkspaceSize, aclnnAvgPool2dBackward)},
    {"Flip", AclOpFunctions(aclnnFlipGetWorkspaceSize, aclnnFlip)},
    {"Concat", AclOpFunctions(aclnnCatGetWorkspaceSize, aclnnCat)},
    {"Gather", AclOpFunctions(aclnnGatherGetWorkspaceSize, aclnnGather)},
    {"Cumsum", AclOpFunctions(aclnnCumsumGetWorkspaceSize, aclnnCumsum)},
    {"Index", AclOpFunctions(aclnnIndexGetWorkspaceSize, aclnnIndex)},
    {"Scatter", AclOpFunctions(aclnnScatterGetWorkspaceSize, aclnnScatter)},
    {"Nonzero", AclOpFunctions(aclnnNonzeroGetWorkspaceSize, aclnnNonzero)},
    {"Where", AclOpFunctions(aclnnSWhereGetWorkspaceSize, aclnnSWhere)},
    {"StridedSliceAssignV2", AclOpFunctions(aclnnStridedSliceAssignV2GetWorkspaceSize, aclnnStridedSliceAssignV2)},
    {"SliceV2", AclOpFunctions(aclnnSliceV2GetWorkspaceSize, aclnnSliceV2)},
    {"IndexPutImpl", AclOpFunctions(aclnnIndexPutImplGetWorkspaceSize, aclnnIndexPutImpl)},
    {"IndexPutImplAccumulate", AclOpFunctions(aclnnIndexPutImplGetWorkspaceSize, aclnnIndexPutImpl)},
    {"Range", AclOpFunctions(aclnnRangeGetWorkspaceSize, aclnnRange)},
    {"ReLU", AclOpFunctions(aclnnReluGetWorkspaceSize, aclnnRelu)},
    {"LeakyReLU", AclOpFunctions(aclnnLeakyReluGetWorkspaceSize, aclnnLeakyRelu)},
    {"LeakyReLUBackward", AclOpFunctions(aclnnLeakyReluBackwardGetWorkspaceSize, aclnnLeakyReluBackward)},
    {"Dropout", AclOpFunctions(aclnnDropoutGetWorkspaceSize, aclnnDropout)},
    {"DropoutBackward", AclOpFunctions(aclnnDropoutBackwardGetWorkspaceSize, aclnnDropoutBackward)},
    {"SiLU", AclOpFunctions(aclnnSiluGetWorkspaceSize, aclnnSilu)},
    {"SiLUBackward", AclOpFunctions(aclnnSiluBackwardGetWorkspaceSize, aclnnSiluBackward)},
    {"SigmoidBackward", AclOpFunctions(aclnnSigmoidBackwardGetWorkspaceSize, aclnnSigmoidBackward)},
    {"Embedding", AclOpFunctions(aclnnEmbeddingGetWorkspaceSize, aclnnEmbedding)},
    {"EmbeddingBackward", AclOpFunctions(aclnnEmbeddingDenseBackwardGetWorkspaceSize, aclnnEmbeddingDenseBackward)},
    {"InplaceMaskedScatter", AclOpFunctions(aclnnInplaceMaskedScatterGetWorkspaceSize, aclnnInplaceMaskedScatter)},
    {"MaskedSelect", AclOpFunctions(aclnnMaskedSelectGetWorkspaceSize, aclnnMaskedSelect)},
    {"SplitWithSize", AclOpFunctions(aclnnSplitWithSizeGetWorkspaceSize, aclnnSplitWithSize)},
    {"Softmax", AclOpFunctions(aclnnSoftmaxGetWorkspaceSize, aclnnSoftmax)},
    {"SoftmaxBackward", AclOpFunctions(aclnnSoftmaxBackwardGetWorkspaceSize, aclnnSoftmaxBackward)},
    {"FlashAttention", AclOpFunctions(aclnnFlashAttentionScoreV2GetWorkspaceSize, aclnnFlashAttentionScoreV2)},
    {"FlashAttentionBackward", AclOpFunctions(aclnnFlashAttentionScoreGradV2GetWorkspaceSize, aclnnFlashAttentionScoreGradV2)},
    {"BatchNorm", AclOpFunctions(aclnnBatchNormGetWorkspaceSize, aclnnBatchNorm)},
    {"BatchNormBackward", AclOpFunctions(aclnnBatchNormBackwardGetWorkspaceSize, aclnnBatchNormBackward)},
    {"LayerNorm", AclOpFunctions(aclnnLayerNormGetWorkspaceSize, aclnnLayerNorm)},
    {"RotaryPosEmb", AclOpFunctions(aclnnApplyRotaryPosEmbGetWorkspaceSize, aclnnApplyRotaryPosEmb)},
    {"Stack", AclOpFunctions(aclnnStackGetWorkspaceSize, aclnnStack)},
    {"NanToNum", AclOpFunctions(aclnnNanToNumGetWorkspaceSize, aclnnNanToNum)},
};
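
// Illustrative sketch (added by the editor, not part of the original source):
// how the table above is meant to be consumed. Every aclnn operator follows
// the same two-phase protocol: first query the workspace size and obtain an
// executor, then launch the kernel on a stream with that workspace. The
// helper below shows the flow for a unary op such as "Abs"; the caller is
// assumed to have already created the `in`/`out` tensors and the stream.
static inline aclnnStatus runUnaryOpByName(const std::string &opName,
                                           aclTensor *in, aclTensor *out,
                                           aclrtStream stream)
{
    auto it = aclOpFuncMap.find(opName);
    if (it == aclOpFuncMap.end())
        return -1; // unknown op name (sentinel chosen for this sketch)
    uint64_t workspaceSize = 0;
    aclOpExecutor *executor = nullptr;
    // Phase 1: query workspace size and build the executor.
    aclnnStatus ret = it->second.getWorkspaceSizeFuncUnaryNonzero(in, out, &workspaceSize, &executor);
    if (ret != ACL_SUCCESS)
        return ret;
    // Phase 2: allocate the workspace (if any) and launch on the stream.
    void *workspaceAddr = nullptr;
    if (workspaceSize > 0 &&
        aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST) != ACL_SUCCESS)
        return -1;
    ret = it->second.executeFunc(workspaceAddr, workspaceSize, executor, stream);
    if (workspaceAddr)
        aclrtFree(workspaceAddr);
    return ret;
}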

struct AclOpAttr
{
    virtual ~AclOpAttr() {}
};

struct ConvAttr : AclOpAttr
{
    vector<int64_t> convStrides;
    vector<int64_t> convPads;
    vector<int64_t> convOutPads;
    vector<int64_t> convDilations;
    bool convWithBias;
    bool is_transposed;
    int64_t group;

    // Destructor
    ~ConvAttr()
    {
        convStrides.clear();
        convPads.clear();
        convOutPads.clear();
        convDilations.clear();
    }
};

struct ReduceAttr : AclOpAttr
{
    vector<int64_t> axes;
    // for proddim
    int64_t prod_dim;
    bool keepdims;

    ~ReduceAttr()
    {
        axes.clear();
    }
};

struct RandomAttr : AclOpAttr
{
    int64_t seed, offset;

    ~RandomAttr()
    {
    }
};

struct TriuAttr : AclOpAttr
{
    int64_t diagonal;

    ~TriuAttr()
    {
    }
};

struct PoolAttr : AclOpAttr
{
    vector<int64_t> kernel_size;
    vector<int64_t> poolStrides;
    vector<int64_t> poolPads;
    vector<int64_t> poolDilations;
    bool poolCeil;
    bool countIncludePad;

    // divisorOverride (const int64_t, compute input): the divisor used when averaging. Data type: INT64. The default value 0 disables the feature.
    // https://www.hiascend.com/document/detail/zh/canncommercial/80RC2/apiref/appdevgapi/context/aclnnAvgPool2d.md
    int64_t divisorOverride = 0;

    // cubeMathType (int8_t, compute input): host-side integer selecting which computation logic the Cube unit uses. Data type: INT8. Unless otherwise noted, computation keeps the original input data type. Supported enum values:
    // 0: KEEP_DTYPE, compute in the input data type. For FLOAT input, Atlas training series and Atlas inference series (Ascend 310P) products do not support this; passing 0 raises an error.
    // 1: ALLOW_FP32_DOWN_PRECISION, allow computing the input at reduced precision. For FLOAT input, Atlas training series and Atlas inference series (Ascend 310P) products may convert it to FLOAT16.
    // 2: USE_FP16, allow converting to FLOAT16 for computation. FLOAT input is converted to FLOAT16.
    // 3: USE_HF32, allow converting to HFLOAT32 for computation. For FLOAT input, Atlas training series, Atlas inference series (Ascend 310P), and Atlas A2 training series / Atlas 800I A2 inference products do not support this; passing 3 raises an error.
    // https://www.hiascend.com/document/detail/zh/canncommercial/80RC2/apiref/appdevgapi/context/aclnnAvgPool2d.md
    int8_t cubeMathType = 0;

    // Destructor
    ~PoolAttr()
    {
        kernel_size.clear();
        poolStrides.clear();
        poolPads.clear();
        poolDilations.clear();
    }
};

struct ConcatAttr : AclOpAttr
{
    int64_t tensorNum;
    int64_t dim;

    ~ConcatAttr()
    {
    }
};

struct GatherAttr : AclOpAttr
{
    int64_t dim;

    ~GatherAttr()
    {
    }
};

struct ScatterAttr : AclOpAttr
{
    int64_t axis;
    int64_t reduction;

    ~ScatterAttr()
    {
    }
};

struct StrideAttr : AclOpAttr
{
    vector<int64_t> begins;
    vector<int64_t> ends;
    vector<int64_t> steps;
    vector<int64_t> axes;
    ~StrideAttr()
    {
        begins.clear();
        ends.clear();
        steps.clear();
        axes.clear();
    }
};

struct RangeAttr : AclOpAttr
{
    int64_t start;
    int64_t end;
    int64_t step;

    ~RangeAttr()
    {
    }
};

struct LeakyReluAttr : AclOpAttr
{
    float negativeSlope;
    bool selfIsResult;

    ~LeakyReluAttr()
    {
    }
};

struct DropoutAttr : AclOpAttr
{
    float p;
    bool train;
    int64_t seed;
    int64_t offset;
    float scale;

    ~DropoutAttr()
    {
    }
};

struct EmbeddingAttr : AclOpAttr
{
    int64_t numEmbeddings;
    // int64_t embeddingDim;
    int64_t paddingIdx;
    bool scaleGradByFreq;
    // bool sparse;
    // bool isSparse;
    // bool isDense;

    ~EmbeddingAttr()
    {
    }
};

struct SplitWithSizeAttr : AclOpAttr
{
    vector<int64_t> splitSize;
    int64_t dim;
    ~SplitWithSizeAttr()
    {
        splitSize.clear();
    }
};

struct SoftmaxAttr : AclOpAttr
{
    int64_t dim;
    ~SoftmaxAttr()
    {
    }
};

struct BatchNormAttr : AclOpAttr
{
    bool is_train;
    float momentum;
    float eps;
    ~BatchNormAttr()
    {
    }
};

struct LayerNormAttr : AclOpAttr
{
    float eps;
    vector<int64_t> normalizedShape;
    int64_t size;
    ~LayerNormAttr()
    {
        normalizedShape.clear();
    }
};

struct FlashAttentionAttr : AclOpAttr
{
    vector<int64_t> prefix;
    vector<int64_t> qStartIdx;
    vector<int64_t> kvStartIdx;
    float scale;
    float keepProb;
    int64_t preToken;
    int64_t nextToken;
    int64_t headNum;
    string inputLayout;
    int64_t innerPrecise;
    int64_t sparseMode;
    int64_t psetype;
    bool hasRealshift;
    bool hasDropmask;
    bool hasPaddingmask;
    bool hasAttentmask;

    ~FlashAttentionAttr()
    {
        prefix.clear();
        qStartIdx.clear();
        kvStartIdx.clear();
    }
};

struct NanToNumAttr : AclOpAttr
{
    float nan;
    float posinf;
    float neginf;
    ~NanToNumAttr()
    {
    }
};
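
// Usage sketch (added for illustration, not part of the original source): a
// runner fills one of the attribute structs above and hands ownership to its
// `op_attr` member (a std::unique_ptr<AclOpAttr> declared on BaseOpRunner in
// base_op.h); the concrete executeOp() later downcasts to recover the
// parameters. Since AclOpAttr has a virtual destructor, dynamic_cast works:
//
//     auto *attr = new ConvAttr();
//     attr->convStrides   = {1, 1};
//     attr->convPads      = {1, 1};
//     attr->convDilations = {1, 1};
//     attr->group = 1;
//     op.op_attr.reset(attr);
//     // ... later, inside executeOp:
//     auto *conv = dynamic_cast<ConvAttr *>(op_attr.get());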
}
@@ -0,0 +1,58 @@
#include <iostream>
#include <vector>
#include "aclnn.h"

int64_t GetShapeSize(const std::vector<int64_t>& shape) {
    int64_t shapeSize = 1;
    for (auto i : shape) {
        shapeSize *= i;
    }
    return shapeSize;
}

void PrintOutResult(std::vector<int64_t> &shape, void** deviceAddr) {
    auto size = GetShapeSize(shape);
    std::vector<int> resultData(size, 0);
    auto ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]),
                           *deviceAddr, size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return);
    for (int64_t i = 0; i < size; i++) {
        LOG_PRINT("mean result[%ld] is: %d\n", i, resultData[i]);
    }
}

/*int Init(int32_t deviceId) {
    // Fixed boilerplate: initialize AscendCL
    auto ret = aclInit(nullptr);
    CHECK_RET(ret == ACL_SUCCESS or ret == 100002, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
    ret = aclrtSetDevice(deviceId);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
    //ret = aclrtCreateStream(stream);
    //CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
    return 0;
}*/

/*
template <typename T>
int CreateAclTensor(const std::vector<T>& hostData, const std::vector<int64_t>& shape, void** deviceAddr,
                    aclDataType dataType, aclTensor** tensor) {
    auto size = GetShapeSize(shape) * sizeof(T);
    // Call aclrtMalloc to allocate device-side memory
    auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
    // Call aclrtMemcpy to copy host-side data into the device memory
    ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);

    // Compute the strides of a contiguous tensor
    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = shape.size() - 2; i >= 0; i--) {
        strides[i] = shape[i + 1] * strides[i + 1];
    }

    // Call aclCreateTensor to create the aclTensor
    *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
                              shape.data(), shape.size(), *deviceAddr);
    return 0;
}*/
@@ -0,0 +1,134 @@
#include <iostream>
#include <vector>
#include "acl.h"
// unary
#include "aclnnop/aclnn_abs.h"
#include "aclnnop/aclnn_neg.h"
#include "aclnnop/aclnn_exp.h"
#include "aclnnop/aclnn_log.h"
#include "aclnnop/aclnn_sqrt.h"
#include "aclnnop/aclnn_ceil.h"
#include "aclnnop/aclnn_floor.h"
#include "aclnnop/aclnn_round.h"
#include "aclnnop/aclnn_sin.h"
#include "aclnnop/aclnn_cos.h"
#include "aclnnop/aclnn_tan.h"
#include "aclnnop/aclnn_asin.h"
#include "aclnnop/aclnn_acos.h"
#include "aclnnop/aclnn_atan.h"
#include "aclnnop/aclnn_sinh.h"
#include "aclnnop/aclnn_cosh.h"
#include "aclnnop/aclnn_tanh.h"
#include "aclnnop/aclnn_asinh.h"
#include "aclnnop/aclnn_acosh.h"
#include "aclnnop/aclnn_atanh.h"
#include "aclnnop/aclnn_sigmoid.h"
#include "aclnnop/aclnn_erf.h"
#include "aclnnop/aclnn_erfinv.h"
#include "aclnnop/aclnn_logical_not.h"
#include "aclnnop/aclnn_bitwise_not.h"
#include "aclnnop/aclnn_cast.h"
#include "aclnnop/aclnn_nonzero.h"
// binary
#include "aclnnop/aclnn_maximum.h"
#include "aclnnop/aclnn_minimum.h"
#include "aclnnop/aclnn_add.h"
#include "aclnnop/aclnn_sub.h"
#include "aclnnop/aclnn_mul.h"
#include "aclnnop/aclnn_div.h"
#include "aclnnop/aclnn_floor_divide.h"
#include "aclnnop/aclnn_le_tensor.h"
#include "aclnnop/aclnn_lt_tensor.h"
#include "aclnnop/aclnn_ge_tensor.h"
#include "aclnnop/aclnn_gt_tensor.h"
#include "aclnnop/aclnn_eq_tensor.h"
#include "aclnnop/aclnn_ne_tensor.h"
#include "aclnnop/aclnn_logical_and.h"
#include "aclnnop/aclnn_logical_or.h"
#include "aclnnop/aclnn_logical_xor.h"
#include "aclnnop/aclnn_bitwise_and_tensor.h"
#include "aclnnop/aclnn_bitwise_or_tensor.h"
#include "aclnnop/aclnn_bitwise_xor_tensor.h"
#include "aclnnop/aclnn_pow_tensor_tensor.h"
#include "aclnnop/aclnn_expand.h"
#include "aclnnop/aclnn_matmul.h"
#include "aclnnop/aclnn_batch_matmul.h"
#include "aclnnop/aclnn_convolution.h"
#include "aclnnop/aclnn_convolution_backward.h"
#include "aclnnop/aclnn_reduce_sum.h"
#include "aclnnop/aclnn_amax.h"
#include "aclnnop/aclnn_amin.h"
#include "aclnnop/aclnn_mean.h"
#include "aclnnop/aclnn_prod.h"
#include "aclnnop/aclnn_triu.h"
#include "aclnnop/aclnn_s_where.h"
#include "aclnnop/aclnn_random.h"
#include "aclnnop/aclnn_normal.h"
#include "aclnnop/aclnn_permute.h"
#include "aclnnop/aclnn_max_pool2d_with_indices.h"
#include "aclnnop/aclnn_max_pool2d_with_indices_backward.h"
#include "aclnnop/aclnn_avgpool2d.h"
#include "aclnnop/aclnn_avgpool2d_backward.h"
#include "aclnnop/aclnn_flip.h"
#include "aclnnop/aclnn_cat.h"
#include "aclnnop/aclnn_gather.h"
#include "aclnnop/aclnn_cumsum.h"
#include "aclnnop/aclnn_index.h"
#include "aclnnop/aclnn_scatter.h"
#include "aclnnop/aclnn_strided_slice_assign_v2.h"
#include "aclnnop/aclnn_slice_v2.h"
#include "aclnnop/aclnn_index_put_impl.h"
#include "aclnnop/aclnn_range.h"
#include "aclnnop/aclnn_relu.h"
#include "aclnnop/aclnn_dropout.h"
#include "aclnnop/aclnn_dropout_backward.h"
#include "aclnnop/aclnn_leaky_relu.h"
#include "aclnnop/aclnn_leaky_relu_backward.h"
#include "aclnnop/aclnn_uniform.h"
#include "aclnnop/aclnn_silu.h"
#include "aclnnop/aclnn_silu_backward.h"
#include "aclnnop/aclnn_sigmoid_backward.h"
#include "aclnnop/aclnn_embedding.h"
#include "aclnnop/aclnn_embedding_dense_backward.h"
#include "aclnnop/aclnn_masked_scatter.h"
#include "aclnnop/aclnn_masked_select.h"
#include "aclnnop/aclnn_split_with_size.h"
#include "aclnnop/aclnn_flash_attention_score.h"
#include "aclnnop/aclnn_flash_attention_score_grad.h"
#include "aclnnop/aclnn_softmax.h"
#include "aclnnop/aclnn_softmax_backward.h"
#include "aclnnop/aclnn_batch_norm.h"
#include "aclnnop/aclnn_batch_norm_backward.h"
#include "aclnnop/aclnn_layer_norm.h"
#include "aclnnop/aclnn_apply_rotary_pos_emb.h"
#include "aclnnop/aclnn_stack.h"
#include "aclnnop/aclnn_nan_to_num.h"

#define CHECK_RET(cond, return_expr) \
    do                               \
    {                                \
        if (!(cond))                 \
        {                            \
            return_expr;             \
        }                            \
    } while (0)

#define LOG_PRINT(message, ...)             \
    do                                      \
    {                                       \
        printf(message, ##__VA_ARGS__);     \
    } while (0)
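
// Usage sketch (added for illustration): CHECK_RET evaluates a condition and,
// on failure, executes the given statement(s), typically a LOG_PRINT followed
// by a return, so call sites stay compact:
//
//     auto ret = aclrtSetDevice(0);
//     CHECK_RET(ret == ACL_SUCCESS,
//               LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);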

int64_t GetShapeSize(const std::vector<int64_t> &shape);

void PrintOutResult(std::vector<int64_t> &shape, void **deviceAddr);

//int Init(int32_t deviceId);

/*
template <typename T>
int CreateAclTensor(const std::vector<T>& hostData, const std::vector<int64_t>& shape, void** deviceAddr,
                    aclDataType dataType, aclTensor** tensor);
*/
@@ -0,0 +1,33 @@
#pragma once
#include <acl/aclops/binary_op_acl.h>
#include <acl/aclops/unary_op_acl.h>
#include <acl/aclops/conv_op_acl.h>
#include <acl/aclops/ternary_op_acl.h>
#include <acl/aclops/reduce_op_acl.h>
#include <acl/aclops/expand_op_acl.h>
#include <acl/aclops/getitem_op_acl.h>
#include <acl/aclops/setitem_op_acl.h>
#include <acl/aclops/matmul_op_acl.h>
#include <acl/aclops/random_op_acl.h>
#include <acl/aclops/bmm_op_acl.h>
#include <acl/aclops/pool_op_acl.h>
#include <acl/aclops/flip_op_acl.h>
#include <acl/aclops/concat_op_acl.h>
#include <acl/aclops/gather_scatter_op_acl.h>
#include <acl/aclops/cumsum_op_acl.h>
#include <acl/aclops/index_op_acl.h>
#include <acl/aclops/where_op_acl.h>
#include <acl/aclops/floor_op_acl.h>
#include <acl/aclops/transpose_op_acl.h>
#include <acl/aclops/flashattention_op_acl.h>
#include <acl/aclops/relu_op_acl.h>
#include <acl/aclops/dropout_op_acl.h>
#include <acl/aclops/silu_op_acl.h>
#include <acl/aclops/sigmoid_op_acl.h>
#include <acl/aclops/softmax_op_acl.h>
#include <acl/aclops/stack_op_acl.h>
#include <acl/aclops/nantonum_op_acl.h>
#include <acl/aclops/rope_op_acl.h>
#include <acl/aclops/triu_op_acl.h>
#include <acl/aclops/embedding_op_acl.h>
#include <acl/aclops/norms_op_acl.h>
@@ -0,0 +1,56 @@
#pragma once
#include "utils.h"
#include "acl_jittor.h"

namespace jittor
{
    extern int sync_run;
    class BaseOpRunner
    {
    protected:
        vector<Var *> in_;
        vector<Var *> out_;

        int ret = -1;
        uint64_t workspaceSize = 0;
        aclOpExecutor *executor;
        bool is_group_op = false;

        std::vector<std::vector<int64_t>> inputShapes;
        std::vector<std::vector<int64_t>> outputShapes;

        std::vector<aclTensor *> inputTensors;
        std::vector<aclTensor *> outputTensors;

    public:
        string name;
        string jt_name;
        std::unique_ptr<AclOpAttr> op_attr;
        bool use_nchw = false;

        BaseOpRunner(const string &name = "") : name(name) {}
        virtual ~BaseOpRunner() = default;

        // Common functionality for adding input/output variables
        void add(Var *v, bool is_input);

        virtual void setupInputDesc();

        void cleanupDesc();

        virtual void setupOutputDesc();

        virtual void syncRun();

        void checkRet(aclnnStatus ret);

        // Base run method with common operator lookup logic
        void run();

    protected:
        // Virtual method for specific operator execution
        virtual void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) = 0;
        void cleanupAttr();
    };

}
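
// Usage sketch (added for illustration, not part of the original source): the
// cuda_src generated by the Python wrappers drives a concrete runner roughly
// like this (`UnaryOpRunner` stands in for any BaseOpRunner subclass; `in0`
// and `out0` are the jittor Vars bound by jt.code):
//
//     UnaryOpRunner op;
//     op.name = "Abs";     // key into aclOpFuncMap
//     op.add(in0, true);   // register input
//     op.add(out0, false); // register output
//     op.run();            // setupInputDesc -> setupOutputDesc -> executeOp -> cleanupDesc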
@@ -0,0 +1,152 @@
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "binary_op_acl.h"
#include "base_op.h"

namespace jittor
{
    extern int sync_run;
    // Common functionality for adding input/output variables
    void BaseOpRunner::add(Var *v, bool is_input)
    {
        if (is_input)
        {
            in_.push_back(v);
        }
        else
        {
            out_.push_back(v);
        }
        return;
    }

    void BaseOpRunner::setupInputDesc()
    {
        auto input_num = in_.size();
        for (int input_idx = 0; input_idx < input_num; input_idx++)
        {
            std::vector<int64_t> shape;
            for (int j = 0; j < in_[input_idx]->shape.size(); j++)
            {
                shape.push_back(in_[input_idx]->shape[j]);
            }
            inputShapes.push_back(shape);
        }

        for (int idx = 0; idx < input_num; idx++)
        {
            inputTensors.push_back(nullptr);
            auto ret = CreateAclTensor(inputShapes[idx], in_[idx]->mem_ptr, in_[idx]->size, get_dtype(in_[idx]->dtype()), &inputTensors[idx], use_nchw);
            CHECK_RET(ret == ACL_SUCCESS, return);
        }
    }

    void BaseOpRunner::cleanupDesc()
    {
        auto input_num = in_.size();
        auto output_num = out_.size();
        for (int idx = 0; idx < input_num; idx++)
        {
            aclDestroyTensor(inputTensors[idx]);
        }
        for (int idx = 0; idx < output_num; idx++)
        {
            aclDestroyTensor(outputTensors[idx]);
        }
    }

    void BaseOpRunner::setupOutputDesc()
    {
        auto output_num = out_.size();

        for (int output_idx = 0; output_idx < output_num; output_idx++)
        {
            std::vector<int64_t> shape;
            for (int j = 0; j < out_[output_idx]->shape.size(); j++)
            {
                shape.push_back(out_[output_idx]->shape[j]);
            }
            outputShapes.push_back(shape);
        }

        for (int idx = 0; idx < output_num; idx++)
        {
            outputTensors.push_back(nullptr);
            auto ret = CreateAclTensor(outputShapes[idx], out_[idx]->mem_ptr, out_[idx]->size, get_dtype(out_[idx]->dtype()), &outputTensors[idx], use_nchw);
            CHECK_RET(ret == ACL_SUCCESS, return);
        }
    }

    void BaseOpRunner::syncRun()
    {
        if (sync_run)
        {
            // ret = aclrtSynchronizeStream(aclstream);
            // CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclrtSynchronizeStream failed. ERROR: %d\n", name.c_str(), ret); return);
        }
    }

    void BaseOpRunner::checkRet(aclnnStatus ret)
    {
        if (ret != ACL_SUCCESS)
        {
            auto tmp_err_msg = aclGetRecentErrMsg();
            LOGir << name << ", " << tmp_err_msg;
        }

        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxxGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return);
    }

    // Base run method with common operator lookup logic
    void BaseOpRunner::run()
    {
        if (is_group_op)
        {
            // Group runners (unary, binary, ...) resolve their aclnn entry
            // points through the map, so an unknown name is a hard error.
            auto it = aclOpFuncMap.find(name);
            if (it == aclOpFuncMap.end())
            {
                LOGir << "aclOpFuncMap Not supported op: " << name;
                throw std::runtime_error("Unsupported operation type.");
            }
            setupInputDesc();
            setupOutputDesc();
            executeOp(it);
            cleanupDesc();
        }
        else
        {
            // Non-group runners may call their aclnn functions directly in
            // executeOp and can ignore the iterator, so a failed lookup is
            // tolerated here.
            auto it = aclOpFuncMap.find(name);
            setupInputDesc();
            setupOutputDesc();
            executeOp(it);
            cleanupDesc();
        }
    }

}
@@ -0,0 +1,124 @@
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "binary_op_acl.h"

namespace jittor
{
    BinaryOpRunner::BinaryOpRunner() : BaseOpRunner("binary")
    {
        use_nchw = false;
        is_group_op = true;
    }

    void BinaryOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        aclScalar *alpha = nullptr;

        // aclnnAdd / aclnnSub compute out = self +/- alpha * other, so an
        // alpha scalar fixed to 1 is created in the dtype of the inputs.
        if (name == string("Add") || name == string("Sub"))
        {
            if (get_dtype(in_[0]->dtype()) == ACL_FLOAT)
            {
                float alphaValue = 1.0;
                alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
            }
            else if (get_dtype(in_[0]->dtype()) == ACL_FLOAT16)
            {
                __fp16 alphaValue = 1.0;
                alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
            }
            else if (get_dtype(in_[0]->dtype()) == ACL_INT64)
            {
                int64_t alphaValue = 1;
                alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
            }
            else if (get_dtype(in_[0]->dtype()) == ACL_INT32)
            {
                int alphaValue = 1;
                alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
            }
            else if (get_dtype(in_[0]->dtype()) == ACL_INT8)
            {
                int8_t alphaValue = 1;
                alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
            }
            else if (get_dtype(in_[0]->dtype()) == ACL_INT16)
            {
                int16_t alphaValue = 1;
                alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
            }
            else if (get_dtype(in_[0]->dtype()) == ACL_UINT8)
            {
                uint8_t alphaValue = 1;
                alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
            }
            else if (get_dtype(in_[0]->dtype()) == ACL_UINT16)
            {
                uint16_t alphaValue = 1;
                alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
            }
            else if (get_dtype(in_[0]->dtype()) == ACL_UINT32)
            {
                uint32_t alphaValue = 1;
                alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
            }
            else if (get_dtype(in_[0]->dtype()) == ACL_BOOL)
            {
                bool alphaValue = true;
                alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
            }
            else
            {
                LOGf << "Not supported dtype: " << in_[0]->dtype();
            }

            CHECK_RET(alpha != nullptr, return);
            ret = it->second.getWorkspaceSizeFuncAdd(inputTensors[0], inputTensors[1], alpha, outputTensors[0], &workspaceSize, &executor);
        }
        else
        {
            ret = it->second.getWorkspaceSizeFuncBinary(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);
        }

        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = it->second.executeFunc(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();

        aclDestroyScalar(alpha);
        return;
    }
}
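
// Usage sketch (added for illustration, not part of the original source):
// computing out = a + b through this runner. The "Add" key selects
// aclnnAddGetWorkspaceSize / aclnnAdd from aclOpFuncMap, and executeOp above
// materializes the alpha scalar before the workspace query.
//
//     BinaryOpRunner op;
//     op.name = "Add";
//     op.add(a, true);
//     op.add(b, true);
//     op.add(out, false);
//     op.run();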
@@ -0,0 +1,14 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    struct BinaryOpRunner : public BaseOpRunner
    {
        BinaryOpRunner();

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
    };
}
@@ -0,0 +1,128 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def acl_cmd(name: str,
            inputs: list,
            output_dtypes: list = None,
            output_shapes: list = None,
            attr_code: str = "",
            attr_header: str = "",
            outputs: list = None,
            extra_data: dict = {}):
    # Builds a jt.code op whose cuda_src drives the C++ runner. Note the
    # runner class is hard-coded to BatchMatMulOpRunner in this file, so the
    # `name` argument only labels the call site.
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    BatchMatMulOpRunner op;
    {input_code}
    op.add(out0, false);
    {attr_code}
    op.run();""",
                   data=extra_data)


class BmmACL(jt.Function):

    def __init__(self, trans_x2=False):
        super(BmmACL, self).__init__()
        self.trans_x2 = trans_x2

    def execute(self, x1, x2):
        self.input = [x1, x2]
        result = acl_cmd(
            "BatchMatMul", [x1, x2],
            output_dtypes=[x1.dtype],
            output_shapes=[
                x1.shape[:-1] + x2.shape[-2:-1]
                if self.trans_x2 else x1.shape[:-1] + x2.shape[-1:]
            ],
            attr_code="op.jt_name=\"bmm_trans_1\";"
            if self.trans_x2 else "op.jt_name=\"bmm\";")[0]

        return result

    def grad(self, grad_output):
        x1, x2 = self.input
        if len(x1) != len(x2):
            reshape_grad_x2 = True
        else:
            reshape_grad_x2 = False
        grad_x1 = acl_cmd(
            "BatchMatMul", [grad_output, x2],
            output_dtypes=[x1.dtype],
            output_shapes=[
                grad_output.shape[:-1] + x2.shape[-2:-1]
                if not self.trans_x2 else grad_output.shape[:-1] + x1.shape[-1:]
            ],
            attr_code="op.jt_name=\"bmm_trans_1\";"
            if not self.trans_x2 else "op.jt_name=\"bmm\";")[0]
        if self.trans_x2:
            if reshape_grad_x2:
                output_shape = grad_output.shape[1:-2] + grad_output.shape[-1:] + x1.shape[-1:]
                grad_x2 = acl_cmd("BatchMatMul", [
                    grad_output.reshape(-1, grad_output.shape[-1]),
                    x1.reshape(-1, x1.shape[-1])
                ],
                                  output_dtypes=[x2.dtype],
                                  output_shapes=[output_shape],
                                  attr_code="op.jt_name=\"bmm_trans_0\";")[0]
            else:
                output_shape = grad_output.shape[:-2] + grad_output.shape[-1:] + x1.shape[-1:]
                grad_x2 = acl_cmd("BatchMatMul", [grad_output, x1],
                                  output_dtypes=[x2.dtype],
                                  output_shapes=[output_shape],
                                  attr_code="op.jt_name=\"bmm_trans_0\";")[0]
        else:
            if reshape_grad_x2:
                output_shape = x1.shape[1:-2] + x1.shape[-1:] + grad_output.shape[-1:]
                grad_x2 = acl_cmd("BatchMatMul", [
                    x1.reshape(-1, x1.shape[-1]),
                    grad_output.reshape(-1, grad_output.shape[-1])
                ],
                                  output_dtypes=[x2.dtype],
                                  output_shapes=[output_shape],
                                  attr_code="op.jt_name=\"bmm_trans_0\";")[0]
            else:
                output_shape = x1.shape[:-2] + x1.shape[-1:] + grad_output.shape[-1:]
                grad_x2 = acl_cmd("BatchMatMul", [x1, grad_output],
                                  output_dtypes=[x2.dtype],
                                  output_shapes=[output_shape],
                                  attr_code="op.jt_name=\"bmm_trans_0\";")[0]
        if len(grad_x1.shape) > len(x1.shape):
            grad_x1 = grad_x1.sum(0)
        if len(grad_x2.shape) > len(x2.shape):
            grad_x2 = grad_x2.sum(0)
        return grad_x1, grad_x2
@@ -0,0 +1,77 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "bmm_op_acl.h"

namespace jittor
{
    BatchMatMulOpRunner::BatchMatMulOpRunner() : BaseOpRunner("BatchMatMulMatMul")
    {
    }
    void BatchMatMulOpRunner::setupInputDesc()
    {
        auto input_num = in_.size();
        for (int input_idx = 0; input_idx < input_num; input_idx++)
        {
            std::vector<int64_t> shape;
            for (int j = 0; j < in_[input_idx]->shape.size(); j++)
            {
                shape.push_back(in_[input_idx]->shape[j]);
            }
            inputShapes.push_back(shape);
        }
        for (int idx = 0; idx < input_num; idx++)
        {
            inputTensors.push_back(nullptr);
            // Operands flagged as transposed go through CreateFakeTransAclTensor
            // instead of CreateAclTensor.
            if ((jt_name == "bmm_trans_1" && idx == 1) || (jt_name == "bmm_trans_0" && idx == 0))
            {
                auto ret = CreateFakeTransAclTensor(inputShapes[idx], in_[idx]->mem_ptr, in_[idx]->size, get_dtype(in_[idx]->dtype()), &inputTensors[idx], use_nchw);
                CHECK_RET(ret == ACL_SUCCESS, return);
            }
            else
            {
                auto ret = CreateAclTensor(inputShapes[idx], in_[idx]->mem_ptr, in_[idx]->size, get_dtype(in_[idx]->dtype()), &inputTensors[idx], use_nchw);
                CHECK_RET(ret == ACL_SUCCESS, return);
            }
        }
    }
    void BatchMatMulOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        ret = aclnnBatchMatMulGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnBatchMatMulGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return);
        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }
        ret = aclnnBatchMatMul(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnBatchMatMul failed. ERROR: %d\n", name.c_str(), ret); return);
        syncRun();
    }
}
@@ -0,0 +1,17 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    class BatchMatMulOpRunner : public BaseOpRunner
    {

    protected:
        void setupInputDesc() override;
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        BatchMatMulOpRunner();
    };
}
@@ -0,0 +1,186 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def concat_cmd(name: str,
               inputs: list,
               output_dtypes: list = None,
               output_shapes: list = None,
               attr_code: str = "",
               attr_header: str = "",
               outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    {name}OpRunner op;
    {input_code}
    {output_code}
    {attr_code}
    op.run();""")


class ConcatACL(jt.Function):

    def __init__(self):
        super(ConcatACL, self).__init__()

    def __call__(self, *args):
        assert isinstance(args[0], (list, tuple))
        assert isinstance(args[1], int)
        if jt.flags.no_grad:
            return self.execute(*args)
        backup = args
        args = list(args)
        taped_inputs = []
        taped_outputs = []
        input_mask = [-1] * (len(args[0]) + 1)
        newargs = [list(), args[1]]
        for i, v in enumerate(args[0]):
            if isinstance(v, jt.Var):
                if v.is_stop_grad():
                    # -2 in input_mask means the input is stop_grad
                    input_mask[i] = -2
                    newargs[0].append(v)
                    continue
                v = v.tape()
                newargs[0].append(v)
                input_mask[i] = len(taped_inputs)
                taped_inputs.append(v)

        ori_res = self.execute(*newargs)
        if not isinstance(ori_res, Sequence):
            res = [ori_res]
        else:
            res = list(ori_res)
        output_mask = [-1] * len(res)
        for i, v in enumerate(res):
            if isinstance(v, jt.Var):
                v = v.tape()
                output_mask[i] = len(taped_outputs)
                res[i] = v
                taped_outputs.append(v)
        self.input_mask = input_mask
        self.output_mask = output_mask
        # tape outputs and inputs together so that
        # backward treats them as one operator
        jt.tape_together(taped_inputs, taped_outputs, self._grad)
        if isinstance(ori_res, Sequence):
            return res
        else:
            return res[0]

    def execute(self, input_tensors, dim=0):
        for t in input_tensors:
            if not (-t.ndim <= dim < t.ndim):
                raise ValueError(f"dim {dim} out of range for shape {t.shape}")

        if dim < 0:
            dim += input_tensors[0].ndim

        self.input = input_tensors
        self.dim = dim
        for i in range(len(input_tensors)):
            if input_tensors[i].dtype != input_tensors[0].dtype:
                raise ValueError("All input tensors must have the same dtype")
            if input_tensors[i].shape[:dim] != input_tensors[0].shape[:dim] \
                    or input_tensors[i].shape[dim + 1:] != input_tensors[0].shape[dim + 1:]:
                raise ValueError("All input tensors must have the same shape except along dim")
        attr_code = f"""
        op.jt_name = "concat";
        ConcatAttr *attr = new ConcatAttr();
        attr->tensorNum = {len(input_tensors)};
        attr->dim = {dim};
        op.op_attr.reset(attr);
        """
        result = concat_cmd(
            "Concat",
            input_tensors,
            output_dtypes=[input_tensors[0].dtype],
            output_shapes=[
                jt.empty(self.calculate_output_shape(input_tensors, dim)).shape
            ],
            attr_code=attr_code)[0]
        return result

    def _grad(self, *args):
        new_args = ((args[i] if i >= 0 else None) for i in self.output_mask)
        ret = self.grad(*new_args)
        new_ret = []
        for i, r in enumerate(ret):
            j = self.input_mask[i]
            if j < 0:
                # -2 in input_mask means the input is stop_grad
                assert r is None or j == -2, f"{type(self)}'s {i}-th returned grad should be None, " \
                    "because the input value is not a jittor variable."
            else:
                new_ret.append(r)
        return new_ret

    def grad(self, grad_output):
        grad_inputs = self.split_grad(grad_output, self.input, self.dim)
        return grad_inputs

    def calculate_output_shape(self, input_tensors, axis):
        shape = list(input_tensors[0].shape)
        for tensor in input_tensors[1:]:
            shape[axis] += tensor.shape[axis]
        return tuple(shape)

    def split_grad(self, grad_output, input_tensors, axis):
        offset = []
        shapeVec = []
        dtypeVec = []
        for tensor in input_tensors:
            offset.append(tensor.shape[axis])
            dtypeVec.append(tensor.dtype)
            shapeVec.append(tensor.shape)

        attr_code = f"""
        op.jt_name = "splitwithsize";
        auto *attr = new SplitWithSizeAttr();
        attr->splitSize = {{ {", ".join(map(str, offset))} }};
        attr->dim = {axis};
        op.op_attr.reset(attr);
        """

        result = concat_cmd("SplitWithSize", [grad_output],
                            output_dtypes=dtypeVec,
                            output_shapes=shapeVec,
                            attr_code=attr_code)
        return result

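A minimal numpy sketch (hypothetical shapes) of the backward rule split_grad encodes above: the gradient of a concat along dim is a split of the upstream gradient using the per-input sizes along that dim.

import numpy as np

a, b = np.ones((2, 3)), np.ones((2, 5))
out = np.concatenate([a, b], axis=1)                # forward: (2, 8)
ga, gb = np.split(np.ones_like(out), [3], axis=1)   # backward: split at offset 3
assert ga.shape == a.shape and gb.shape == b.shape
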
@@ -0,0 +1,89 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "concat_op_acl.h"

namespace jittor
{
    ConcatOpRunner::ConcatOpRunner() : BaseOpRunner("Concat")
    {
    }

    void ConcatOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto input_num = in_.size();
        std::vector<aclTensor *> concatTensorList = {};
        for (int i = 0; i < input_num; i++)
        {
            concatTensorList.push_back(inputTensors[i]);
        }
        auto concatTensorListInput = aclCreateTensorList(&concatTensorList[0], input_num);
        auto attr = dynamic_cast<ConcatAttr *>(op_attr.get());
        ret = aclnnCatGetWorkspaceSize(concatTensorListInput, attr->dim, outputTensors[0], &workspaceSize, &executor);
        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnCat(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnCat failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();
        return;
    }

    SplitWithSizeOpRunner::SplitWithSizeOpRunner() : BaseOpRunner("SplitWithSize")
    {
    }

    void SplitWithSizeOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto output_num = out_.size();
        auto attr = dynamic_cast<SplitWithSizeAttr *>(op_attr.get());
        auto splitSize = aclCreateIntArray(attr->splitSize.data(), attr->splitSize.size());
        auto tensorList = aclCreateTensorList(&outputTensors[0], output_num);
        ret = aclnnSplitWithSizeGetWorkspaceSize(inputTensors[0], splitSize, attr->dim, tensorList, &workspaceSize, &executor);
        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnSplitWithSize(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSplitWithSize failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();
        return;
    }

}

@@ -0,0 +1,26 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    class ConcatOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        ConcatOpRunner();
    };

    class SplitWithSizeOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        SplitWithSizeOpRunner();
    };
}

@@ -0,0 +1,160 @@
import os
import jittor_utils
from jittor_utils import env_or_try_find
import ctypes
import glob
import jittor as jt
import jittor.compiler as compiler
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def _ntuple(n):

    def parse(x):
        if isinstance(x, Iterable):
            return x
        return tuple([x] * n)

    return parse


_pair = _ntuple(2)


def conv_cmd(name: str,
             inputs: list,
             output_dtypes: list = None,
             output_shapes: list = None,
             attr_code: str = "",
             attr_header: str = "",
             outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    {name}OpRunner op;
    {input_code}
    {output_code}
    {attr_code}
    op.run();""")


class ConvACL(jt.Function):

    def execute(self,
                x,
                weight,
                bias=None,
                stride=1,
                padding=0,
                dilation=1,
                groups=1):
        self.input = x
        self.weight = weight
        self.bias = bias
        padding = _pair(padding)
        stride = _pair(stride)
        dilation = _pair(dilation)
        out_channels = weight.shape[0]
        if groups <= 0:
            raise ValueError("groups must be a positive integer")
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        attr_code = f"""
        op.jt_name = "conv2d";
        ConvAttr *attr = new ConvAttr();
        attr->convStrides = {{ {stride[0]}, {stride[1]} }};
        attr->convPads = {{ {padding[0]}, {padding[1]} }};
        attr->convDilations = {{ {dilation[0]}, {dilation[1]} }};
        attr->group = {groups};
        attr->convOutPads = {{1, 1}};
        op.op_attr.reset(attr);
        """
        input_height, input_width = x.shape[-2:]
        kernel_height, kernel_width = weight.shape[-2:]

        output_height = (input_height + 2 * padding[0] - dilation[0] *
                         (kernel_height - 1) - 1) // stride[0] + 1
        output_width = (input_width + 2 * padding[1] - dilation[1] *
                        (kernel_width - 1) - 1) // stride[1] + 1

        output_shape = (x.shape[0], out_channels, output_height, output_width)

        inputs = [x, weight]
        if bias is not None:
            inputs.append(bias)
        result = conv_cmd(
            "Conv2d",
            inputs,
            output_dtypes=[x.dtype],
            output_shapes=[output_shape],
            attr_code=attr_code,
        )[0]
        return result

    def grad(self, grad_output):
        x = self.input
        weight = self.weight
        bias = self.bias
        inputs = [grad_output, x, weight]
        if bias is not None:
            inputs.append(bias)
        output_shapes = [x.shape, weight.shape]
        output_dtypes = [x.dtype, weight.dtype]
        if bias is not None:
            output_shapes.append(bias.shape)
            output_dtypes.append(bias.dtype)
        else:
            output_shapes.append([weight.shape[0]])
            output_dtypes.append(x.dtype)
        padding = self.padding
        stride = self.stride
        dilation = self.dilation
        groups = self.groups
        attr_code = f"""
        op.jt_name = "conv2dbackward";
        ConvAttr *attr = new ConvAttr();
        attr->convStrides = {{ {stride[0]}, {stride[1]} }};
        attr->convPads = {{ {padding[0]}, {padding[1]} }};
        attr->convDilations = {{ {dilation[0]}, {dilation[1]} }};
        attr->group = {groups};
        attr->convOutPads = {{1, 1}};
        op.op_attr.reset(attr);
        """
        results = conv_cmd("Conv2dBackward",
                           inputs,
                           output_dtypes=output_dtypes,
                           output_shapes=output_shapes,
                           attr_code=attr_code)
        if self.bias is None:
            return results[0], results[1]

        return results

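The output spatial size computed in execute follows the standard dilated-convolution formula; a quick check with assumed values:

H, K, P, S, D = 32, 3, 1, 2, 1   # input size, kernel, padding, stride, dilation (hypothetical)
out = (H + 2 * P - D * (K - 1) - 1) // S + 1
assert out == 16
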
@@ -0,0 +1,152 @@
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "conv_op_acl.h"

namespace jittor
{
    Conv2dOpRunner::Conv2dOpRunner() : BaseOpRunner("Conv2d")
    {
        use_nchw = true;
    }

    void Conv2dOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        // for conv
        aclIntArray *strides = nullptr;
        aclIntArray *pads = nullptr;
        aclIntArray *outPads = nullptr;
        aclIntArray *dilations = nullptr;
        auto attr = dynamic_cast<ConvAttr *>(op_attr.get());
        strides = aclCreateIntArray(attr->convStrides.data(), 2);
        pads = aclCreateIntArray(attr->convPads.data(), 2);
        outPads = aclCreateIntArray(attr->convOutPads.data(), 2);
        dilations = aclCreateIntArray(attr->convDilations.data(), 2);

        aclTensor *bias = nullptr;

        auto input_num = in_.size();
        if (input_num == 3)
            bias = inputTensors[2];

        ret = aclnnConvolutionGetWorkspaceSize(inputTensors[0], inputTensors[1], bias, strides, pads, dilations, false, outPads, attr->group, outputTensors[0], 0, &workspaceSize, &executor);
        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnConvolution(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnConvolution failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();

        aclDestroyIntArray(strides);
        aclDestroyIntArray(pads);
        aclDestroyIntArray(outPads);
        aclDestroyIntArray(dilations);
        return;
    }

    Conv2dBackwardOpRunner::Conv2dBackwardOpRunner() : BaseOpRunner("Conv2dBackward")
    {
        use_nchw = true;
    }

    void Conv2dBackwardOpRunner::setupOutputDesc()
    {
        auto output_num = out_.size();

        for (int output_idx = 0; output_idx < output_num; output_idx++)
        {
            std::vector<int64_t> shape;
            for (int j = 0; j < out_[output_idx]->shape.size(); j++)
            {
                shape.push_back(out_[output_idx]->shape[j]);
            }
            outputShapes.push_back(shape);
        }

        for (int idx = 0; idx < 2; idx++)
        {
            outputTensors.push_back(nullptr);
            auto ret = CreateAclTensor(outputShapes[idx], out_[idx]->mem_ptr, out_[idx]->size, get_dtype(out_[idx]->dtype()), &outputTensors[idx], use_nchw);
            CHECK_RET(ret == ACL_SUCCESS, return);
        }
        // bias grad uses nd format
        {
            outputTensors.push_back(nullptr);
            auto ret = CreateAclTensor(outputShapes[2], out_[2]->mem_ptr, out_[2]->size, get_dtype(out_[2]->dtype()), &outputTensors[2], false);
            CHECK_RET(ret == ACL_SUCCESS, return);
        }
    }

    void Conv2dBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        // for conv
        aclIntArray *strides = nullptr;
        aclIntArray *pads = nullptr;
        aclIntArray *outPads = nullptr;
        aclIntArray *dilations = nullptr;
        auto attr = dynamic_cast<ConvAttr *>(op_attr.get());
        strides = aclCreateIntArray(attr->convStrides.data(), 2);
        pads = aclCreateIntArray(attr->convPads.data(), 2);
        outPads = aclCreateIntArray(attr->convOutPads.data(), 2);
        dilations = aclCreateIntArray(attr->convDilations.data(), 2);
        bool outputMask[3] = {true, true, true};
        auto input_num = in_.size();
        if (input_num == 3)
        {
            outputMask[2] = false;
        }
        aclBoolArray *outMask = aclCreateBoolArray(outputMask, 3);
        auto biasSizes = aclCreateIntArray(&outputShapes[2][0], outputShapes[2].size());
        ret = aclnnConvolutionBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], biasSizes, strides, pads, dilations, false, outPads, attr->group, outMask, 0, outputTensors[0], outputTensors[1], outputTensors[2], &workspaceSize, &executor);
        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnConvolutionBackward(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnConvolutionBackward failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();

        aclDestroyIntArray(strides);
        aclDestroyIntArray(pads);
        aclDestroyIntArray(outPads);
        aclDestroyIntArray(dilations);
        return;
    }
}

@@ -0,0 +1,27 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    class Conv2dOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        Conv2dOpRunner();
    };

    class Conv2dBackwardOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
        void setupOutputDesc() override;

    public:
        Conv2dBackwardOpRunner();
    };
}

@@ -0,0 +1,101 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def cumsum_cmd(name: str,
               inputs: list,
               output_dtypes: list = None,
               output_shapes: list = None,
               attr_code: str = "",
               attr_header: str = "",
               outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    {name}OpRunner op;
    {input_code}
    {output_code}
    {attr_code}
    op.run();""")


class CumsumACL(jt.Function):

    def __init__(self):
        super(CumsumACL, self).__init__()

    def execute(self, input, dim=-1):
        self.dim = dim
        attr_code = f"""
        op.jt_name = "cumsum";
        GatherAttr *attr = new GatherAttr();
        attr->dim = {dim};
        op.op_attr.reset(attr);
        """
        result = cumsum_cmd("Cumsum", [input],
                            output_dtypes=[input.dtype],
                            output_shapes=[input.shape],
                            attr_code=attr_code)[0]
        return result

    def grad(self, grad_output):
        cumsum_attr_code = f"""
        op.jt_name = "cumsum";
        GatherAttr *attr = new GatherAttr();
        attr->dim = {self.dim};
        op.op_attr.reset(attr);
        """
        flip_attr_code = f"""
        op.jt_name = "flip";
        ReduceAttr *attr = new ReduceAttr();
        attr->axes = {{{self.dim}}};
        attr->prod_dim = {{{1}}};
        op.op_attr.reset(attr);
        """
        flipped_grad_output = cumsum_cmd("Flip", [grad_output],
                                         output_dtypes=[grad_output.dtype],
                                         output_shapes=[grad_output.shape],
                                         attr_code=flip_attr_code)[0]
        cumulative_grad = cumsum_cmd("Cumsum", [flipped_grad_output],
                                     output_dtypes=[grad_output.dtype],
                                     output_shapes=[grad_output.shape],
                                     attr_code=cumsum_attr_code)[0]
        grad_input = cumsum_cmd("Flip", [cumulative_grad],
                                output_dtypes=[grad_output.dtype],
                                output_shapes=[grad_output.shape],
                                attr_code=flip_attr_code)[0]
        return grad_input

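The flip-cumsum-flip chain in grad implements the identity grad(cumsum) = flip(cumsum(flip(g))), i.e. a reversed cumulative sum of the upstream gradient; a numpy sketch with a hypothetical gradient vector:

import numpy as np

g = np.arange(5.0)                               # upstream gradient (hypothetical)
ref = np.array([g[i:].sum() for i in range(5)])  # d(cumsum)/dx_i = sum over j >= i of g_j
assert np.allclose(np.flip(np.cumsum(np.flip(g))), ref)
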
@@ -0,0 +1,57 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "cumsum_op_acl.h"

namespace jittor
{
    CumsumOpRunner::CumsumOpRunner() : BaseOpRunner("Cumsum")
    {
    }

    void CumsumOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto attr = dynamic_cast<GatherAttr *>(op_attr.get());
        ret = aclnnCumsumGetWorkspaceSize(inputTensors[0], attr->dim, get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor);
        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnCumsum(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnCumsum failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();
        return;
    }

}

@@ -0,0 +1,17 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    class CumsumOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        CumsumOpRunner();
    };

}

@@ -0,0 +1,94 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def dropout_cmd(name: str,
                inputs: list,
                output_dtypes: list = None,
                output_shapes: list = None,
                attr_code: str = "",
                attr_header: str = "",
                outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    {name}OpRunner op;
    {input_code}
    {output_code}
    {attr_code}
    op.run();""")


class DropoutACL(jt.Function):

    def __init__(self):
        super(DropoutACL, self).__init__()

    def execute(self, x, p=0.5, is_train=False):
        self.input = x
        num_elements = x.numel()
        aligned_elements = (num_elements + 127) // 128 * 128
        mask_shape = (aligned_elements // 8, )
        attr_code = f"""
        op.jt_name = "dropout";
        DropoutAttr *attr = new DropoutAttr();
        attr->p = {p};
        attr->train = {"true" if is_train else "false"};
        attr->seed = 0;
        attr->offset = 0;
        op.op_attr.reset(attr);
        """
        result = dropout_cmd("Dropout", [x],
                             output_dtypes=[x.dtype, "uint8"],
                             output_shapes=[x.shape, mask_shape],
                             attr_code=attr_code)
        self.maskout = result[1]
        return result[0]

    def grad(self, grad_output):
        attr_code = f"""
        op.jt_name = "dropoutbackward";
        DropoutAttr *attr = new DropoutAttr();
        attr->scale = 1.0;
        op.op_attr.reset(attr);
        """
        grad_input = dropout_cmd("DropoutBackward",
                                 [grad_output, self.maskout],
                                 output_dtypes=[grad_output.dtype],
                                 output_shapes=[grad_output.shape],
                                 attr_code=attr_code)[0]
        return grad_input

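The mask shape above comes from bit-packing: one mask bit per element, padded up to a multiple of 128 elements and stored as uint8. A quick check of that arithmetic with an assumed element count:

num_elements = 1000                          # hypothetical numel
aligned = (num_elements + 127) // 128 * 128  # pad to 128-element blocks -> 1024
mask_bytes = aligned // 8                    # 8 mask bits per uint8 -> 128
assert (aligned, mask_bytes) == (1024, 128)
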
@@ -0,0 +1,82 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "dropout_op_acl.h"

namespace jittor
{
    DropoutOpRunner::DropoutOpRunner() : BaseOpRunner("Dropout")
    {
    }

    void DropoutOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto attr = dynamic_cast<DropoutAttr *>(op_attr.get());
        ret = aclnnDropoutGetWorkspaceSize(inputTensors[0], attr->p, attr->train, attr->seed, attr->offset, outputTensors[0], outputTensors[1], &workspaceSize, &executor);
        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnDropout(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnDropout failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();

        return;
    }

    DropoutBackwardOpRunner::DropoutBackwardOpRunner() : BaseOpRunner("DropoutBackward")
    {
    }

    void DropoutBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto attr = dynamic_cast<DropoutAttr *>(op_attr.get());
        ret = aclnnDropoutBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], attr->scale, outputTensors[0], &workspaceSize, &executor);
        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnDropoutBackward(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnDropoutBackward failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();

        return;
    }

}

@@ -0,0 +1,27 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    class DropoutOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        DropoutOpRunner();
    };

    class DropoutBackwardOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        DropoutBackwardOpRunner();
    };

}

@@ -0,0 +1,91 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def embedding_cmd(name: str,
                  inputs: list,
                  output_dtypes: list = None,
                  output_shapes: list = None,
                  attr_code: str = "",
                  attr_header: str = "",
                  outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    {name}OpRunner op;
    {input_code}
    {output_code}
    {attr_code}
    op.run();""")


class EmbeddingACL(jt.Function):

    def __init__(self):
        super(EmbeddingACL, self).__init__()

    def execute(
        self,
        indices,
        weight,
    ):
        inputs = [weight, indices]
        self.indices = indices
        self.weight_shape = weight.shape
        output_shape = list(indices.shape) + list(weight.shape[1:])
        outputs = [jt.empty(output_shape, weight.dtype)]
        attr_code = f"""
        op.jt_name = "embedding";
        """
        result = embedding_cmd("Embedding",
                               inputs=inputs,
                               outputs=outputs,
                               attr_code=attr_code)[0]
        return result

    def grad(self, grad_output):
        inputs = [grad_output, self.indices]
        outputs = [jt.empty(self.weight_shape, grad_output.dtype)]
        attr_code = f"""
        op.jt_name = "embeddingbackward";
        EmbeddingAttr *attr = new EmbeddingAttr();
        attr->numEmbeddings = {self.weight_shape[0]};
        op.op_attr.reset(attr);
        """
        grad_weight = embedding_cmd("EmbeddingBackward",
                                    inputs=inputs,
                                    outputs=outputs,
                                    attr_code=attr_code)[0]
        return None, grad_weight

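The output shape rule used in execute (indices.shape + weight.shape[1:]) is ordinary table lookup; a numpy sketch with hypothetical sizes:

import numpy as np

weight = np.random.rand(10, 4)                # 10 embeddings of dim 4 (hypothetical)
indices = np.array([[1, 2], [3, 4], [5, 6]])
out = weight[indices]                         # lookup along the first axis
assert out.shape == indices.shape + weight.shape[1:]   # (3, 2, 4)
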
@@ -0,0 +1,82 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "embedding_op_acl.h"

namespace jittor
{
    EmbeddingOpRunner::EmbeddingOpRunner() : BaseOpRunner("Embedding")
    {
    }

    void EmbeddingOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        ret = aclnnEmbeddingGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);
        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnEmbedding(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnEmbedding failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();

        return;
    }

    EmbeddingBackwardOpRunner::EmbeddingBackwardOpRunner() : BaseOpRunner("EmbeddingBackward")
    {
    }

    void EmbeddingBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto attr = dynamic_cast<EmbeddingAttr *>(op_attr.get());
        auto numEmbeddings = attr->numEmbeddings;
        ret = aclnnEmbeddingDenseBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], numEmbeddings, 0, false, outputTensors[0], &workspaceSize, &executor);
        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnEmbeddingDenseBackward(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnEmbeddingDenseBackward failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();

        return;
    }

}

@@ -0,0 +1,25 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    class EmbeddingOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        EmbeddingOpRunner();
    };

    class EmbeddingBackwardOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        EmbeddingBackwardOpRunner();
    };

}

@@ -0,0 +1,58 @@
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "expand_op_acl.h"

namespace jittor
{
    ExpandOpRunner::ExpandOpRunner() : BaseOpRunner("ternary")
    {
        use_nchw = false;
    }

    void ExpandOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        aclIntArray *size = nullptr;
        size = aclCreateIntArray(&outputShapes[0][0], outputShapes[0].size());
        ret = aclnnExpandGetWorkspaceSize(inputTensors[0], size, outputTensors[0], &workspaceSize, &executor);
        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnExpand(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnExpand failed. ERROR: %d\n", name.c_str(), ret); return);

        aclDestroyIntArray(size);

        return;
    }
}

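aclnnExpand broadcasts the input to the recorded output shape; the same semantics in numpy, with hypothetical shapes:

import numpy as np

x = np.ones((1, 3))
y = np.broadcast_to(x, (4, 3))   # what the expand op computes
assert y.shape == (4, 3)
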
@@ -0,0 +1,14 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    struct ExpandOpRunner : public BaseOpRunner
    {
        ExpandOpRunner();

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
    };
}

@@ -0,0 +1,209 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def flashattention_cmd(name: str,
                       inputs: list,
                       output_dtypes: list = None,
                       output_shapes: list = None,
                       attr_code: str = "",
                       attr_header: str = "",
                       outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    {name}OpRunner op;
    {input_code}
    {output_code}
    {attr_code}
    op.run();""")


class FlashAttentionACL(jt.Function):

    def __init__(self,
                 headnum,
                 layout="BNSD",
                 prefix=None,
                 qstart=None,
                 kvstart=None,
                 scale=1.0,
                 prob=1.0,
                 pretokens=2147483647,
                 nexttokens=2147483647,
                 innerprecise=0,
                 sparsemode=0,
                 psetype=1):
        self.headnum = headnum
        self.layout = layout
        self.scale = scale
        self.prob = prob
        self.pretokens = pretokens
        self.nexttokens = nexttokens
        self.innerprecise = innerprecise
        self.sparsemode = sparsemode
        self.psetype = psetype
        self.prefix = prefix
        self.qstart = qstart
        self.kvstart = kvstart

    def execute(
        self,
        q,
        k,
        v,
        realshift=None,
        dropMask=None,
        paddingMask=None,
        attenMask=None,
    ):
        if self.layout == 'BSH':
            B, SQ, H = q.shape
            SKV = k.shape[1]
            N = self.headnum
            D = H // N
        elif self.layout == 'SBH':
            SQ, B, H = q.shape
            SKV = k.shape[0]
            N = self.headnum
            D = H // N
        elif self.layout == 'BSND':
            B, SQ, N, D = q.shape
            SKV = k.shape[1]
        elif self.layout == 'BNSD':
            B, N, SQ, D = q.shape
            SKV = k.shape[2]
        else:
            raise ValueError(f"got invalid input layout {self.layout}")

        output_shape = (B, N, SQ, 8)

        self.q = q
        self.k = k
        self.v = v

        self.prefix = self.prefix if self.prefix else [0 for _ in range(B)]
        self.qstart = self.qstart if self.qstart else [0 for _ in range(B)]
        self.kvstart = self.kvstart if self.kvstart else [0 for _ in range(B)]

        self.hasRealshift = realshift is not None
        self.hasDropmask = dropMask is not None
        self.hasPaddingmask = paddingMask is not None
        self.hasAttenmask = attenMask is not None

        # TBD: absent optional masks are currently passed as nullptr on the C++ side
        self.realshift = realshift if realshift else jt.zeros(B, N, SQ, SKV)
        self.dropMask = dropMask if dropMask else jt.ones(B, N, SQ, SKV)
        self.paddingMask = paddingMask if paddingMask else jt.zeros(
            B, N, SQ, SKV)
        self.attenMask = attenMask if attenMask else jt.zeros(SQ, SKV)

        attr_code = f"""
        op.jt_name = "flashattention";
        FlashAttentionAttr *attr = new FlashAttentionAttr();
        attr->scale = {self.scale};
        attr->keepProb = {self.prob};
        attr->preToken = {self.pretokens};
        attr->nextToken = {self.nexttokens};
        attr->headNum = {self.headnum};
        attr->inputLayout = "{self.layout}";
        attr->innerPrecise = {self.innerprecise};
        attr->sparseMode = {self.sparsemode};
        attr->psetype = {self.psetype};
        attr->prefix = {{ {", ".join(map(str, self.prefix))} }};
        attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }};
        attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }};
        attr->hasRealshift = {"true" if self.hasRealshift else "false"};
        attr->hasDropmask = {"true" if self.hasDropmask else "false"};
        attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"};
        attr->hasAttentmask = {"true" if self.hasAttenmask else "false"};
        op.op_attr.reset(attr);
        """

        inputs = [
            q, k, v, self.realshift, self.dropMask, self.paddingMask,
            self.attenMask
        ]

        result = flashattention_cmd(
            "FlashAttention",
            inputs,
            output_dtypes=["float", "float", q.dtype],
            output_shapes=[output_shape, output_shape, q.shape],
            attr_code=attr_code)

        self.maxout = result[0]
        self.sumout = result[1]
        self.attenout = result[2]

        return self.attenout

    def grad(self, dy):
        attr_code = f"""
        op.jt_name = "flashattentionbackward";
        FlashAttentionAttr *attr = new FlashAttentionAttr();
        attr->scale = {self.scale};
        attr->keepProb = {self.prob};
        attr->preToken = {self.pretokens};
        attr->nextToken = {self.nexttokens};
        attr->headNum = {self.headnum};
        attr->inputLayout = "{self.layout}";
        attr->innerPrecise = {self.innerprecise};
        attr->sparseMode = {self.sparsemode};
        attr->psetype = {self.psetype};
        attr->prefix = {{ {", ".join(map(str, self.prefix))} }};
        attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }};
        attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }};
        attr->hasRealshift = {"true" if self.hasRealshift else "false"};
        attr->hasDropmask = {"true" if self.hasDropmask else "false"};
        attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"};
        attr->hasAttentmask = {"true" if self.hasAttenmask else "false"};
        op.op_attr.reset(attr);
        """
        inputs = [
            self.q, self.k, self.v, dy, self.realshift, self.dropMask,
            self.paddingMask, self.attenMask, self.maxout, self.sumout,
            self.attenout
        ]

        result = flashattention_cmd(
            "FlashAttentionBackward",
            inputs,
            output_dtypes=[self.q.dtype, self.k.dtype, self.v.dtype],
            output_shapes=[self.q.shape, self.k.shape, self.v.shape],
            attr_code=attr_code)
        return result

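Shape bookkeeping behind the layout branches above, with hypothetical sizes: in the BSH/SBH layouts the per-head dim D is recovered from the packed hidden size H, so H is assumed divisible by headnum, and the softmax statistics (maxout/sumout) are stored per (batch, head, query) with a fixed inner size of 8:

B, N, SQ, D = 2, 4, 16, 32          # hypothetical sizes
H = N * D                           # packed hidden size in BSH/SBH
assert H // N == D
softmax_stats_shape = (B, N, SQ, 8) # matches output_shape above
assert softmax_stats_shape == (2, 4, 16, 8)
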
@@ -0,0 +1,88 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "flashattention_op_acl.h"

namespace jittor
{
    FlashAttentionOpRunner::FlashAttentionOpRunner() : BaseOpRunner("FlashAttention")
    {
    }

    void FlashAttentionOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto attr = dynamic_cast<FlashAttentionAttr *>(op_attr.get());
        auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size());
        auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size());
        auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size());
        char *layout = const_cast<char *>(attr->inputLayout.data());
        ret = aclnnFlashAttentionScoreV2GetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], attr->hasRealshift ? inputTensors[3] : nullptr, attr->hasDropmask ? inputTensors[4] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[6] : nullptr, prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], nullptr, outputTensors[2], &workspaceSize, &executor);
        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnFlashAttentionScoreV2(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnFlashAttentionScoreV2 failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();
        return;
    }

    FlashAttentionBackwardOpRunner::FlashAttentionBackwardOpRunner() : BaseOpRunner("FlashAttentionBackward")
    {
    }

    void FlashAttentionBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto attr = dynamic_cast<FlashAttentionAttr *>(op_attr.get());
        auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size());
        auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size());
        auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size());
        char *layout = const_cast<char *>(attr->inputLayout.data());
        ret = aclnnFlashAttentionScoreGradV2GetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], attr->hasRealshift ? inputTensors[4] : nullptr, attr->hasDropmask ? inputTensors[5] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[7] : nullptr, inputTensors[8], inputTensors[9], nullptr, inputTensors[10], prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], outputTensors[2], nullptr, &workspaceSize, &executor);
        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnFlashAttentionScoreGradV2(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnFlashAttentionScoreGradV2 failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();
        return;
    }

}

@@ -0,0 +1,27 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    class FlashAttentionOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        FlashAttentionOpRunner();
    };

    class FlashAttentionBackwardOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        FlashAttentionBackwardOpRunner();
    };

}

@@ -0,0 +1,85 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def flip_cmd(name: str,
             inputs: list,
             output_dtypes: list = None,
             output_shapes: list = None,
             attr_code: str = "",
             attr_header: str = "",
             outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    {name}OpRunner op;
    {input_code}
    {output_code}
    {attr_code}
    op.run();""")


class FlipACL(jt.Function):

    def __init__(self):
        super(FlipACL, self).__init__()

    def execute(self, input, dim):
        if isinstance(dim, tuple):
            dim = list(dim)
        if not isinstance(dim, list):
            dim = [dim]
        attr_code = f"""
        op.jt_name = "flip";
        ReduceAttr *attr = new ReduceAttr();
        attr->axes = {{{', '.join(map(str, dim))}}};
        attr->prod_dim = {len(dim)};
        op.op_attr.reset(attr);
        """
        self.attr_code = attr_code
        result = flip_cmd("Flip", [input],
                          output_dtypes=[input.dtype],
                          output_shapes=[input.shape],
                          attr_code=self.attr_code)[0]
        return result

    def grad(self, grad_output):
        grad_input = flip_cmd("Flip", [grad_output],
                              output_dtypes=[grad_output.dtype],
                              output_shapes=[grad_output.shape],
                              attr_code=self.attr_code)[0]
        return grad_input

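Flip is an involution, which is why grad can reuse the forward attr_code unchanged; a numpy sketch:

import numpy as np

x = np.arange(6.0).reshape(2, 3)
assert np.array_equal(np.flip(np.flip(x, 1), 1), x)   # flipping twice is the identity
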
@@ -0,0 +1,58 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "flip_op_acl.h"

namespace jittor
{
    FlipOpRunner::FlipOpRunner() : BaseOpRunner("Flip")
    {
    }

    void FlipOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto attr = dynamic_cast<ReduceAttr *>(op_attr.get());
        auto dim = aclCreateIntArray(attr->axes.data(), attr->axes.size());
        ret = aclnnFlipGetWorkspaceSize(inputTensors[0], dim, outputTensors[0], &workspaceSize, &executor);
        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnFlip(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnFlip failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();
        return;
    }

}

@@ -0,0 +1,16 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    class FlipOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        FlipOpRunner();
    };
}

@@ -0,0 +1,70 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def floor_cmd(name: str,
              inputs: list,
              output_dtypes: list = None,
              output_shapes: list = None,
              attr_code: str = "",
              attr_header: str = "",
              outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    {name}OpRunner op;
    {input_code}
    {output_code}
    {attr_code}
    op.run();""")


class FloorIntACL(jt.Function):

    def __init__(self):
        super(FloorIntACL, self).__init__()

    def execute(self, input):
        self.shape = input.shape
        result = floor_cmd("Floor", [input],
                           output_dtypes=[input.dtype],
                           output_shapes=[input.shape],
                           attr_code='op.jt_name="floor";')[0]
        return result

    def grad(self, grad_output):
        return jt.zeros(self.shape, dtype=grad_output.dtype)

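floor is piecewise constant, so its derivative is zero almost everywhere; that is why grad simply returns zeros of the input shape. A numpy sketch:

import numpy as np

x = np.array([0.4, 1.7, -2.3])
assert np.array_equal(np.floor(x + 1e-6), np.floor(x))   # locally constant away from integers
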
@@ -0,0 +1,56 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "floor_op_acl.h"

namespace jittor
{
    FloorOpRunner::FloorOpRunner() : BaseOpRunner("Floor")
    {
    }

    void FloorOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        ret = aclnnFloorGetWorkspaceSize(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnFloor(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnFloor failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();
        return;
    }

}

@ -0,0 +1,16 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    class FloorOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        FloorOpRunner();
    };
}
@ -0,0 +1,126 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def gather_scatter_cmd(name: str,
                       inputs: list,
                       output_dtypes: list = None,
                       output_shapes: list = None,
                       attr_code: str = "",
                       attr_header: str = "",
                       outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    {name}OpRunner op;
    {input_code}
    {output_code}
    {attr_code}
    op.run();""")


class GatherACL(jt.Function):

    def __init__(self):
        super(GatherACL, self).__init__()

    def execute(self, input, dim, index):
        self.dim = dim
        self.index = index
        # the backward must scatter into a tensor shaped like the input
        self.input_shape = input.shape
        attr_code = f"""
        op.jt_name = "gather";
        GatherAttr *attr = new GatherAttr();
        attr->dim = {dim};
        op.op_attr.reset(attr);
        """
        result = gather_scatter_cmd("Gather", [input, index],
                                    output_dtypes=[input.dtype],
                                    output_shapes=[index.shape],
                                    attr_code=attr_code)[0]
        return result

    def grad(self, grad_output):
        # gather backward: scatter-add grad_output into zeros shaped like
        # the forward input (reduction 1 = add)
        tmp = jt.zeros(self.input_shape, dtype=grad_output.dtype)
        attr_code = f"""
        op.jt_name = "scatter";
        ScatterAttr *attr = new ScatterAttr();
        attr->axis = {self.dim};
        attr->reduction = {1};
        op.op_attr.reset(attr);
        """
        grad_input = gather_scatter_cmd("Scatter",
                                        [tmp, self.index, grad_output],
                                        output_dtypes=[grad_output.dtype],
                                        output_shapes=[tmp.shape],
                                        attr_code=attr_code)[0]
        return grad_input, None, None


class ScatterACL(jt.Function):

    def __init__(self):
        super(ScatterACL, self).__init__()

    def execute(self, input, dim, index, src, reduce='void'):
        self.dim = dim
        self.index = index
        self.reduce = reduce
        attr_code = f"""
        op.jt_name = "scatter";
        ScatterAttr *attr = new ScatterAttr();
        attr->axis = {dim};
        attr->reduction = {1 if reduce == 'add' else 2 if reduce == 'mul' else 0};
        op.op_attr.reset(attr);
        """
        result = gather_scatter_cmd("Scatter", [input, self.index, src],
                                    output_dtypes=[input.dtype],
                                    output_shapes=[input.shape],
                                    attr_code=attr_code)[0]
        return result

    def grad(self, grad_output):
        attr_code = f"""
        op.jt_name = "gather";
        GatherAttr *attr = new GatherAttr();
        attr->dim = {self.dim};
        op.op_attr.reset(attr);
        """
        # the gradient w.r.t. src is grad_output gathered back at index
        grad_src = gather_scatter_cmd("Gather", [grad_output, self.index],
                                      output_dtypes=[grad_output.dtype],
                                      output_shapes=[self.index.shape],
                                      attr_code=attr_code)[0]
        return grad_output, None, None, grad_src
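The pair above mirrors torch-style gather/scatter semantics: GatherACL reads out[i][j] = input[i][index[i][j]] along dim, and its backward scatter-adds the incoming gradient (reduction code 1 = add, matching the mapping used in ScatterACL.execute). A hedged sketch, assuming an ACL build:

import jittor as jt

x = jt.array([[1., 2.], [3., 4.]])
idx = jt.array([[0, 0], [1, 0]]).int32()

g = GatherACL()(x, 1, idx)   # [[1., 1.], [4., 3.]]
s = ScatterACL()(jt.zeros((2, 2)), 1, idx, g, reduce='add')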
@ -0,0 +1,80 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "gather_scatter_op_acl.h"

namespace jittor
{
    GatherOpRunner::GatherOpRunner() : BaseOpRunner("Gather")
    {
    }

    void GatherOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto attr = dynamic_cast<GatherAttr *>(op_attr.get());
        ret = aclnnGatherGetWorkspaceSize(inputTensors[0], attr->dim, inputTensors[1], outputTensors[0], &workspaceSize, &executor);

        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnGather(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnGather failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();
        return;
    }

    ScatterOpRunner::ScatterOpRunner() : BaseOpRunner("Scatter")
    {
    }

    void ScatterOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto attr = dynamic_cast<ScatterAttr *>(op_attr.get());
        ret = aclnnScatterGetWorkspaceSize(inputTensors[0], attr->axis, inputTensors[1], inputTensors[2], attr->reduction, outputTensors[0], &workspaceSize, &executor);

        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnScatter(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnScatter failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();
        return;
    }

}
@ -0,0 +1,26 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    class GatherOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        GatherOpRunner();
    };

    class ScatterOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        ScatterOpRunner();
    };
}
@ -0,0 +1,419 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def getitem_cmd(name: str,
                inputs: list,
                output_dtypes: list = None,
                output_shapes: list = None,
                attr_code: str = "",
                attr_header: str = "",
                outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    {name}OpRunner op;
    {input_code}
    {output_code}
    {attr_code}
    op.run();""")


def getitem_forward(name: str,
                    inputs: list,
                    output_dtypes: list = None,
                    output_shapes: list = None,
                    attr_code: str = "",
                    attr_header: str = "",
                    outputs: list = None,
                    extra_data: dict = {}):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    {name}OpRunner op;
    {input_code}
    op.add(out0, false);
    {attr_code}
    op.run();""",
                   data=extra_data)
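getitem_forward differs from getitem_cmd in one respect: the extra_data dict is forwarded via jt.code(..., data=extra_data), so the generated C++ can read runtime slice parameters through data["..."] instead of recompiling a kernel for every distinct slice. The sketch below (values are illustrative only) shows the dict layout GetItemACL.execute assembles for the SliceV2 path: key "a" holds the sliced dim count, and keys str(dim*3), str(dim*3+1), str(dim*3+2) hold each dim's start/stop/step:

# Illustrative extra_data for slicing a 2-D tensor as x[0:4:2, 0:5:1];
# the slicev2 attr_code below reads these via data["a"], data["0"], ...
extra_data = {
    "a": 2,                    # number of sliced dimensions
    "0": 0, "1": 4, "2": 2,    # dim 0: start, stop, step
    "3": 0, "4": 5, "5": 1,    # dim 1: start, stop, step
}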
def calculate_shape(tensors):
    if isinstance(tensors, jt.Var):
        return tensors.shape
    elif isinstance(tensors, (int, float)):
        return []
    elif isinstance(tensors, (list, tuple)):
        sub_shape = calculate_shape(tensors[0])
        return [len(tensors)] + sub_shape
    else:
        assert False, f"not implemented for {type(tensors)}"


def can_broadcast_and_shape(shape1, shape2):
    """
    Check whether two tensor shapes are broadcastable and, if so, compute
    the broadcast shape.

    Args:
    - shape1: shape of the first tensor (tuple or list)
    - shape2: shape of the second tensor (tuple or list)

    Returns:
    - can_broadcast: bool, whether the two shapes can be broadcast together
    - broadcast_shape: the broadcast shape if broadcastable, otherwise None
    """
    # Convert the shapes to tuples in case the inputs are lists
    shape1 = tuple(shape1)
    shape2 = tuple(shape2)

    # Make both shapes the same length by padding the shorter one with
    # leading 1s
    len1, len2 = len(shape1), len(shape2)
    if len1 < len2:
        shape1 = (1, ) * (len2 - len1) + shape1
    elif len2 < len1:
        shape2 = (1, ) * (len1 - len2) + shape2

    broadcast_shape = []

    # Check each dimension pair
    for dim1, dim2 in zip(shape1, shape2):
        if dim1 == dim2:
            broadcast_shape.append(dim1)
        elif dim1 == 1:
            broadcast_shape.append(dim2)
        elif dim2 == 1:
            broadcast_shape.append(dim1)
        else:
            # Incompatible in this dimension, so the shapes cannot broadcast
            return False, None

    return True, tuple(broadcast_shape)
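For example, (3, 1, 5) against (4, 5) pads the second shape to (1, 4, 5) and broadcasts to (3, 4, 5), while two differing non-1 dims cannot broadcast:

ok, shape = can_broadcast_and_shape((3, 1, 5), (4, 5))
assert ok and shape == (3, 4, 5)

ok, shape = can_broadcast_and_shape((3, 2), (3, 4))
assert not ok and shape is None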


class GetItemACL(jt.Function):

    def __init__(self):
        self.type_ = 'notype'

    def stride(self, x, dim):
        stride = 1
        for i in range(dim + 1, len(x.shape)):
            stride *= x.shape[i]
        return stride

    def execute(self, x, slices, return_x=None):
        if isinstance(slices, jt.Var) and slices.dtype == 'bool':
            # Boolean-mask path. TODO: optimize; MaskedSelect writes at most
            # numel(x) elements, so select into a full-size buffer and trim.
            assert x.shape == slices.shape, "shape not match"
            output_len = slices.sum().item()
            x_len = x.numel()
            output = jt.empty((x_len, ), dtype=x.dtype)
            outputs = [output]
            inputs = [x, slices]
            self.mask = slices
            self.type_ = 'mask'
            attr_code = f"""
            op.jt_name = "maskedselect";
            """
            result = getitem_cmd("MaskedSelect",
                                 inputs=inputs,
                                 outputs=outputs,
                                 attr_code=attr_code)[0]
            result = result[:output_len]
            result.sync()
            return result
        self.x_shape = x.shape
        if not isinstance(slices, tuple):
            slices = (slices, )
        slices = list(slices)
        for i, s in enumerate(slices):
            if isinstance(s, int) and s < 0:
                slices[i] = s + x.shape[i]
        slices = tuple(slices)
        slices_list = list(slices)
        # check whether slices contains a slice (or Ellipsis) entry
        contains_slice = False
        for s in slices:
            if not isinstance(s, jt.Var) and (isinstance(s, slice)
                                              or s == Ellipsis):
                contains_slice = True
                break
        if not contains_slice:
            indices = []
            output_shape = []
            slices_len = len(slices)
            broadcast_shape = calculate_shape(slices_list[0])
            for ii in range(1, len(slices)):
                dd, broadcast_shape = can_broadcast_and_shape(
                    broadcast_shape, calculate_shape(slices_list[ii]))
                assert dd is True, "can not broadcast"
            output_shape = broadcast_shape
            output_shape += x.shape[slices_len:]
            if output_shape == []:
                output_shape = [1]
            for ii in slices:
                indices.append(jt.Var(ii).int32())
            if isinstance(slices[0], (jt.Var, int, list, tuple)):
                self.indices = indices
                inputs = [x] + indices
                attr_code = f"""
                op.jt_name = "index";
                """
                self.type_ = 'index'
                result = getitem_cmd("Index",
                                     inputs=inputs,
                                     output_dtypes=[x.dtype],
                                     output_shapes=[output_shape],
                                     attr_code=attr_code)[0]
                result.sync()
                return result
        assert contains_slice, "slice type error"
        x_dim = len(x.shape)
        slices = list(slices)
        for s in slices:
            if not isinstance(s, jt.Var) and s == Ellipsis:
                slices = slices[:slices.index(s)] + [
                    slice(None, None, None)
                ] * (x_dim - len(slices) + 1) + slices[slices.index(s) + 1:]
                break
        slices = tuple(slices)

        if len(slices) < x_dim:
            slices += (slice(None, None, None), ) * (x_dim - len(slices))
        inputs = [x]
        sizes = []
        begins = []
        ends = []
        steps = []
        dims = []
        squeeze_dims = []

        extra_data = {}
        if len(slices):
            extra_data["a"] = len(slices)
            for dim, s in enumerate(slices):
                if isinstance(s, int):
                    s = slice(s, s + 1, 1)
                    squeeze_dims.append(dim)
                if isinstance(s, jt.Var):
                    assert False, "jt.Var not supported"
                start, stop, step = s.indices(x.size(dim))
                size = (stop - start - 1) // step + 1
                sizes.append(size)
                extra_data[str(dim * 3)] = start
                extra_data[str(dim * 3 + 1)] = stop
                extra_data[str(dim * 3 + 2)] = step

                steps.append(step)
                begins.append(start)
                ends.append(stop)
                dims.append(dim)
        else:
            extra_data["a"] = -1
            sizes = [1]
            steps = [1]
        self.type_ = 'slicev2'
        # saved for backward
        self.begins = begins
        self.ends = ends
        self.steps = steps
        self.dims = dims

        self.slices = slices
        attr_code = """
        op.jt_name = "slicev2";
        StrideAttr *attr = new StrideAttr();

        int slice_dim = data["a"];

        if(slice_dim == -1) {
            attr->begins = {};
            attr->ends = {};
            attr->steps = {1};
            attr->axes = {};
        } else {
            vector<long int> begins;
            vector<long int> ends;
            vector<long int> steps;
            vector<long int> dims;
            for(int dim = 0; dim < slice_dim; dim++) {
                dims.push_back(dim);
                begins.push_back(data[std::to_string(dim*3)]);
                ends.push_back(data[std::to_string(dim*3+1)]);
                steps.push_back(data[std::to_string(dim*3+2)]);
            }
            attr->begins = begins;
            attr->ends = ends;
            attr->steps = steps;
            attr->axes = dims;
        }
        op.op_attr.reset(attr);
        """
        result = getitem_forward("SliceV2",
                                 inputs,
                                 output_dtypes=[x.dtype],
                                 output_shapes=[jt.empty(sizes).shape],
                                 attr_code=attr_code,
                                 extra_data=extra_data)[0]
        self.squeeze_dims = squeeze_dims
        for dim in squeeze_dims[::-1]:
            result = jt.squeeze(result, dim)
        result.sync()
        return result
    def grad(self, grad_output):
        if self.type_ == 'index':
            indices = self.indices
            inputs = [grad_output] + indices
            attr_code = f"""
            op.jt_name = "indexputimplaccumulate";
            """
            outputs = [jt.zeros(self.x_shape, dtype=grad_output.dtype)]
            result = getitem_cmd("IndexPutImplAccumulate",
                                 inputs=inputs,
                                 outputs=outputs,
                                 attr_code=attr_code)[0]
            result.sync()
            return result, None
        elif self.type_ == 'slicev2':
            begins = self.begins
            ends = self.ends
            steps = self.steps
            dims = self.dims
            slices = self.slices
            # The forward pass may have squeezed dimensions, so restore them
            # before scattering the gradient back
            for dim in self.squeeze_dims:
                grad_output = jt.unsqueeze(grad_output, dim)
            # Work around an Ascend constraint: the step of the last
            # dimension must be 1, so append a full slice (and a trailing
            # dimension) when it is not
            expand_dim = False
            if isinstance(slices[-1], slice):
                if slices[-1].step is not None and slices[-1].step != 1:
                    slices = slices + (slice(None, None, None), )
                    expand_dim = True
            elif isinstance(slices[-1], int):
                # The last entry is an integer index: turn it into a
                # one-element slice first
                slices = list(slices)
                slices[-1] = slice(slices[-1], slices[-1] + 1, 1)
                slices = tuple(slices)
                slices = slices + (slice(None, None, None), )
                expand_dim = True
            else:
                assert False, "not supported"
            if expand_dim:
                grad_output = grad_output.unsqueeze(-1)
                self.x_shape = self.x_shape + (1, )
            sizes = []
            begins = []
            ends = []
            steps = []
            dims = []
            for dim, s in enumerate(slices):
                if isinstance(s, int):
                    s = slice(s, s + 1, 1)
                if isinstance(s, jt.Var):
                    assert False, "jt.Var not supported"
                start, stop, step = s.indices(self.x_shape[dim])
                size = (stop - start - 1) // step + 1
                sizes.append(size)
                steps.append(step)
                begins.append(start)
                ends.append(stop)
                dims.append(dim)
            if not sizes:
                sizes = [1]
                steps = [1]
            attr_code = f"""
            op.jt_name = "stridedsliceassignv2";
            StrideAttr *attr = new StrideAttr();
            attr->begins = {{ {", ".join(map(str, begins))} }};
            attr->ends = {{ {", ".join(map(str, ends))} }};
            attr->steps = {{ {", ".join(map(str, steps))} }};
            attr->axes = {{ {", ".join(map(str, dims))} }};
            op.op_attr.reset(attr);
            """
            inputs = [grad_output]
            outputs = [jt.zeros(self.x_shape, dtype=grad_output.dtype)]
            result = getitem_cmd("StridedSliceAssignV2",
                                 inputs=inputs,
                                 outputs=outputs,
                                 attr_code=attr_code)[0]
            result.sync()
            if expand_dim:
                result = result.squeeze(-1)
            return result, None
        elif self.type_ == 'mask':
            return self.mask.float()
        else:
            assert False, f"grad not implemented for {self.type_}"
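In summary, execute dispatches to one of three paths and records the choice in self.type_ so that grad can invert it: boolean masks go through MaskedSelect (backward returns the mask), integer/tensor indices through Index (backward scatter-accumulates via IndexPutImplAccumulate), and slices through SliceV2 (backward writes into zeros via StridedSliceAssignV2). A hedged usage sketch, assuming an ACL build:

import jittor as jt

x = jt.rand(4, 5)
m = GetItemACL()(x, x > 0.5)                        # mask  -> MaskedSelect
i = GetItemACL()(x, (jt.array([0, 2]),))            # index -> Index
s = GetItemACL()(x, (slice(0, 4, 2), slice(None)))  # slice -> SliceV2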
@ -0,0 +1,165 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "getitem_op_acl.h"

namespace jittor
{
    MaskedSelectOpRunner::MaskedSelectOpRunner() : BaseOpRunner("MaskedSelect")
    {
    }

    void MaskedSelectOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        ret = aclnnMaskedSelectGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);

        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnMaskedSelect(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnMaskedSelect failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();
        return;
    }

    IndexOpRunner::IndexOpRunner() : BaseOpRunner("Index")
    {
    }

    void IndexOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto input_num = in_.size();
        auto indexTensorList = aclCreateTensorList(&inputTensors[1], input_num - 1);
        ret = aclnnIndexGetWorkspaceSize(inputTensors[0], indexTensorList, outputTensors[0], &workspaceSize, &executor);

        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnIndex(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnIndex failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();
        return;
    }

    SliceV2OpRunner::SliceV2OpRunner() : BaseOpRunner("SliceV2")
    {
    }

    void SliceV2OpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto attr = dynamic_cast<StrideAttr *>(op_attr.get());
        auto begins = aclCreateIntArray(attr->begins.data(), attr->begins.size());
        auto ends = aclCreateIntArray(attr->ends.data(), attr->ends.size());
        auto steps = aclCreateIntArray(attr->steps.data(), attr->steps.size());
        auto axes = aclCreateIntArray(attr->axes.data(), attr->axes.size());
        ret = aclnnSliceV2GetWorkspaceSize(inputTensors[0], begins, ends, axes, steps, outputTensors[0], &workspaceSize, &executor);

        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnSliceV2(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSliceV2 failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();

        return;
    }

    IndexPutImplAccumulateOpRunner::IndexPutImplAccumulateOpRunner() : BaseOpRunner("IndexPutImplAccumulate")
    {
    }

    void IndexPutImplAccumulateOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto input_num = in_.size();
        std::vector<aclTensor *> indexTensorList = {};
        for (int i = 1; i < input_num; i++)
        {
            indexTensorList.push_back(inputTensors[i]);
        }
        auto indexTensorListInput = aclCreateTensorList(&indexTensorList[0], input_num - 1);
        ret = aclnnIndexPutImplGetWorkspaceSize(outputTensors[0], indexTensorListInput, inputTensors[0], true, true, &workspaceSize, &executor);

        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnIndexPutImpl(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnIndexPutImpl failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();

        return;
    }

    StridedSliceAssignV2OpRunner::StridedSliceAssignV2OpRunner() : BaseOpRunner("StridedSliceAssignV2")
    {
    }

    void StridedSliceAssignV2OpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto attr = dynamic_cast<StrideAttr *>(op_attr.get());
        auto begins = aclCreateIntArray(attr->begins.data(), attr->begins.size());
        auto ends = aclCreateIntArray(attr->ends.data(), attr->ends.size());
        auto steps = aclCreateIntArray(attr->steps.data(), attr->steps.size());
        auto axes = aclCreateIntArray(attr->axes.data(), attr->axes.size());
        ret = aclnnStridedSliceAssignV2GetWorkspaceSize(outputTensors[0], inputTensors[0], begins, ends, steps, axes, &workspaceSize, &executor);

        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnStridedSliceAssignV2(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnStridedSliceAssignV2 failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();

        return;
    }
}
@ -0,0 +1,57 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    class MaskedSelectOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        MaskedSelectOpRunner();
    };

    class IndexOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        IndexOpRunner();
    };

    class SliceV2OpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        SliceV2OpRunner();
    };

    class IndexPutImplAccumulateOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        IndexPutImplAccumulateOpRunner();
    };

    class StridedSliceAssignV2OpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        StridedSliceAssignV2OpRunner();
    };

}
@ -0,0 +1,107 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def range_forward(name: str,
                  inputs: list,
                  output_dtypes: list = None,
                  output_shapes: list = None,
                  attr_code: str = "",
                  attr_header: str = "",
                  outputs: list = None,
                  extra_data: dict = {}):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    {name}OpRunner op;
    {input_code}
    op.add(out0, false);
    {attr_code}
    op.run();""",
                   data=extra_data)


class IndexACL(jt.Function):

    def __init__(self):
        super(IndexACL, self).__init__()

    def execute(self, inshape: list, dim=None, dtype="int32"):
        # build coordinate tensors over inshape: for each requested dim,
        # produce a tensor of that shape whose entries are the index along
        # that dim
        dim_input = dim
        if dim is None:
            dim = [i for i in range(len(inshape))]
        elif isinstance(dim, int):
            dim = [dim]
        results = []
        extra_data = {}
        extra_data["dim_count"] = len(dim)

        for i, d in enumerate(dim):
            max_len = inshape[d]

            extra_data[f"dim_{i}_start"] = 0
            extra_data[f"dim_{i}_end"] = max_len
            extra_data[f"dim_{i}_step"] = 1

            tmp = jt.zeros(max_len, dtype=dtype)
            range_attr_code = f"""
            op.jt_name = "range";
            RangeAttr *attr = new RangeAttr();
            attr->start = data["dim_{i}_start"];
            attr->end = data["dim_{i}_end"];
            attr->step = data["dim_{i}_step"];
            op.op_attr.reset(attr);
            """
            result = range_forward("Range", [],
                                   output_dtypes=[tmp.dtype],
                                   output_shapes=[tmp.shape],
                                   attr_code=range_attr_code,
                                   extra_data=extra_data)[0]
            broadcast_dims = list(range(len(inshape)))
            broadcast_dims.remove(d)
            result = jt.broadcast(result, shape=inshape, dims=broadcast_dims)
            results.append(result)

        if len(results) != 1 or dim_input is None:
            return tuple(results)
        elif len(results) == 1 and dim_input is not None:
            return results[0]
        else:
            return results

    def grad(self, grad_output):
        return grad_output
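IndexACL reproduces jt.index-style coordinate tensors: for each requested dim it runs the ACL Range runner (again passing start/end/step through data) and broadcasts the result over the remaining dims, so element [i0, i1, ...] of the d-th output equals i_d. A sketch, assuming an ACL build:

rows = IndexACL()([2, 3], dim=0)   # [[0, 0, 0], [1, 1, 1]]
cols = IndexACL()([2, 3], dim=1)   # [[0, 1, 2], [0, 1, 2]]
rows, cols = IndexACL()([2, 3])    # dim=None -> one tensor per dim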
@ -0,0 +1,72 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "index_op_acl.h"

namespace jittor
{
    RangeOpRunner::RangeOpRunner() : BaseOpRunner("Range")
    {
    }

    void RangeOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        aclScalar *start = nullptr;
        aclScalar *end = nullptr;
        aclScalar *step = nullptr;

        auto attr = dynamic_cast<RangeAttr *>(op_attr.get());
        int64_t startValue = attr->start;
        int64_t endValue = attr->end;
        int64_t stepValue = attr->step;
        start = aclCreateScalar(&startValue, aclDataType::ACL_INT64);
        end = aclCreateScalar(&endValue, aclDataType::ACL_INT64);
        step = aclCreateScalar(&stepValue, aclDataType::ACL_INT64);

        ret = aclnnRangeGetWorkspaceSize(start, end, step, outputTensors[0], &workspaceSize, &executor);

        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnRange(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnRange failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();

        aclDestroyScalar(start);
        aclDestroyScalar(end);
        aclDestroyScalar(step);
        return;
    }

}
@ -0,0 +1,17 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    class RangeOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        RangeOpRunner();
    };

}
@ -0,0 +1,130 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def matmul_forward(name: str,
                   inputs: list,
                   output_dtypes: list = None,
                   output_shapes: list = None,
                   attr_code: str = "",
                   attr_header: str = "",
                   outputs: list = None,
                   extra_data: dict = {}):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    MatMulOpRunner op;
    {input_code}
    op.add(out0, false);
    {attr_code}
    op.run();""",
                   data=extra_data)


class MatmulACL(jt.Function):

    def __init__(self, trans_x2=False):
        super(MatmulACL, self).__init__()
        self.trans_x2 = trans_x2

    def execute(self, x1, x2):
        self.input = [x1, x2]
        result = matmul_forward(
            "MatMul", [x1, x2],
            output_dtypes=[x1.dtype],
            output_shapes=[
                x1.shape[:-1] +
                x2.shape[-2:-1] if self.trans_x2 else x1.shape[:-1] +
                x2.shape[-1:]
            ],
            attr_code="op.jt_name=\"matmul_trans_1\";"
            if self.trans_x2 else "op.jt_name=\"matmul\";")[0]
        return result

    def grad(self, grad_output):
        x1, x2 = self.input
        # when x1 and x2 differ in rank (e.g. a batched x1 against a 2-D x2),
        # flatten the batch dimensions before computing grad_x2
        if x1.ndim != x2.ndim:
            reshape_grad_x2 = True
        else:
            reshape_grad_x2 = False
        grad_x1 = matmul_forward(
            "MatMul", [grad_output, x2],
            output_dtypes=[x1.dtype],
            output_shapes=[
                grad_output.shape[:-1] + x2.shape[-2:-1] if not self.trans_x2
                else grad_output.shape[:-1] + x2.shape[-1:]
            ],
            attr_code="op.jt_name=\"matmul_trans_1\";"
            if not self.trans_x2 else "op.jt_name=\"matmul\";")[0]

        if self.trans_x2:
            if reshape_grad_x2:
                output_shape = grad_output.shape[1:-2] + grad_output.shape[
                    -1:] + x1.shape[-1:]
                grad_x2 = matmul_forward(
                    "MatMul", [
                        grad_output.reshape(-1, grad_output.shape[-1]),
                        x1.reshape(-1, x1.shape[-1])
                    ],
                    output_dtypes=[x2.dtype],
                    output_shapes=[output_shape],
                    attr_code="op.jt_name=\"matmul_trans_0\";")[0]
            else:
                output_shape = grad_output.shape[:-2] + grad_output.shape[
                    -1:] + x1.shape[-1:]
                grad_x2 = matmul_forward(
                    "MatMul", [grad_output, x1],
                    output_dtypes=[x2.dtype],
                    output_shapes=[output_shape],
                    attr_code="op.jt_name=\"matmul_trans_0\";")[0]
        else:
            if reshape_grad_x2:
                output_shape = x1.shape[1:-2] + x1.shape[
                    -1:] + grad_output.shape[-1:]
                grad_x2 = matmul_forward(
                    "MatMul", [
                        x1.reshape(-1, x1.shape[-1]),
                        grad_output.reshape(-1, grad_output.shape[-1])
                    ],
                    output_dtypes=[x2.dtype],
                    output_shapes=[output_shape],
                    attr_code="op.jt_name=\"matmul_trans_0\";")[0]
            else:
                output_shape = x1.shape[:-2] + x1.shape[
                    -1:] + grad_output.shape[-1:]
                grad_x2 = matmul_forward(
                    "MatMul", [x1, grad_output],
                    output_dtypes=[x2.dtype],
                    output_shapes=[output_shape],
                    attr_code="op.jt_name=\"matmul_trans_0\";")[0]
        return grad_x1, grad_x2
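MatmulACL computes x1 @ x2, or x1 @ x2.T when built with trans_x2=True; the transpose is never materialized, since MatMulOpRunner::setupInputDesc (below) describes the same buffer with swapped shape/strides via CreateFakeTransAclTensor. A hedged sketch, assuming an ACL build:

import jittor as jt

a = jt.rand(8, 4)
b = jt.rand(4, 6)
c = MatmulACL()(a, b)               # (8, 6)

w = jt.rand(6, 4)
d = MatmulACL(trans_x2=True)(a, w)  # a @ w.T -> (8, 6)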
@ -0,0 +1,77 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "matmul_op_acl.h"

namespace jittor
{
    MatMulOpRunner::MatMulOpRunner() : BaseOpRunner("MatMul")
    {
    }
    void MatMulOpRunner::setupInputDesc()
    {
        auto input_num = in_.size();
        for (int input_idx = 0; input_idx < input_num; input_idx++)
        {
            std::vector<int64_t> shape;
            for (int j = 0; j < in_[input_idx]->shape.size(); j++)
            {
                shape.push_back(in_[input_idx]->shape[j]);
            }
            inputShapes.push_back(shape);
        }
        for (int idx = 0; idx < input_num; idx++)
        {
            inputTensors.push_back(nullptr);
            if ((jt_name == "matmul_trans_1" && idx == 1) || (jt_name == "matmul_trans_0" && idx == 0))
            {
                // "fake trans": describe the same buffer with a transposed
                // shape/stride layout instead of materializing a copy
                auto ret = CreateFakeTransAclTensor(inputShapes[idx], in_[idx]->mem_ptr, in_[idx]->size, get_dtype(in_[idx]->dtype()), &inputTensors[idx], use_nchw);
                CHECK_RET(ret == ACL_SUCCESS, return);
            }
            else
            {
                auto ret = CreateAclTensor(inputShapes[idx], in_[idx]->mem_ptr, in_[idx]->size, get_dtype(in_[idx]->dtype()), &inputTensors[idx], use_nchw);
                CHECK_RET(ret == ACL_SUCCESS, return);
            }
        }
    }
    void MatMulOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        ret = aclnnMatmulGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnMatmulGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return);
        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }
        ret = aclnnMatmul(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnMatmul failed. ERROR: %d\n", name.c_str(), ret); return);
        syncRun();
    }
}
@ -0,0 +1,17 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    class MatMulOpRunner : public BaseOpRunner
    {
    protected:
        void setupInputDesc() override;
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        MatMulOpRunner();
    };
}