diff --git a/doc/build_doc.sh b/doc/build_doc.sh index d8c53288..cfa9d84d 100755 --- a/doc/build_doc.sh +++ b/doc/build_doc.sh @@ -16,3 +16,4 @@ echo "[jittor path] $jittor_path" export PYTHONPATH=$jittor_path/python cd $bpath sphinx-autobuild -b html source build -H 0.0.0.0 -p 8890 + diff --git a/python/jittor/compatibility/__init__.py b/python/jittor/compatibility/__init__.py new file mode 100644 index 00000000..94d2e40b --- /dev/null +++ b/python/jittor/compatibility/__init__.py @@ -0,0 +1,430 @@ +# import os +# os.environ["FIX_TORCH_ERROR"] = "0" + +# import jittor as jt +# from jittor import * +# from typing import Tuple + +# org_int = int = type(1) +# org_float = float = type(1.0) +# org_bool = bool = type(True) + +# import jtorch.compiler + +# import jtorch_core +# from jtorch_core import * + +# device.__reduce__ = lambda self: (device, (self.type,)) +# device.__module__ = "jtorch" +# jt.jittor_core.device = device + +# def handle_dtype(args, kw, dtype): +# def convert(x): +# if isinstance(x, jt.Var): +# return x.cast(dtype) +# return x +# if dtype is not None: +# if args is not None: +# if isinstance(args, (tuple,list)): +# args = [ convert(a) for a in args ] +# else: +# args = convert(x) +# if kw is not None: +# kw = { k:convert(v) for k,v in kw.items() } +# return args, kw + +# def get_args_names(func): +# import inspect +# spec = inspect.getfullargspec(func) +# return spec[0] + spec[4] + +# def wrapper(func): +# has_dtype = False +# if hasattr(func, "__code__"): +# has_dtype = "dtype" in get_args_names(func) +# def inner(*args, **kw): +# requires_grad = None +# dtype = None +# if "requires_grad" in kw: +# requires_grad = kw["requires_grad"] +# del kw["requires_grad"] +# if not has_dtype and "dtype" in kw: +# dtype = kw["dtype"] +# del kw["dtype"] +# if "device" in kw: +# del kw["device"] +# if 'pin_memory' in kw: +# del kw['pin_memory'] +# args, kw = handle_dtype(args, kw, dtype) +# ret = func(*args, **kw) +# if isinstance(ret, jt.Var): +# if requires_grad is not None: +# ret.requires_grad = requires_grad +# if dtype is not None: +# ret.astype(dtype) +# return ret +# return inner + + +# import inspect +# _wrapper_keys = set(["shape", "start", "size"]) +# _wrapper_keys.add("x") +# for k,v in list(globals().items()): +# if callable(v) and not isinstance(v, type): +# try: +# spec = inspect.getfullargspec(v) +# args_name = spec[0] +# if len(args_name) and args_name[0] in _wrapper_keys: +# globals()[k] = wrapper(v) +# elif spec.varargs in _wrapper_keys: +# globals()[k] = wrapper(v) +# except: +# pass + +# def empty(*size, dtype=jt.float32, device=None, requires_grad=False): +# if len(size) == 1 and not isinstance(size[0], org_int): +# size = size[0] +# return jt.empty(size, dtype) + +# Tensor = Var + +# Tensor.backward = lambda x: jtorch_core.backward(x) +# Tensor.grad = property(grad_get, grad_set, grad_del) +# Tensor.retains_grad = property(retain_grad_get, retain_grad_set) +# def retain_grad(x:Tensor, value:bool=True): +# x.retains_grad = value +# return value +# Tensor.retain_grad = retain_grad + +# Tensor.dim = lambda self: self.ndim +# Tensor.ndimension = lambda self: self.ndim +# Tensor.nelement = lambda self: self.numel() +# Tensor.cuda = lambda self: self +# def device_get(x:Tensor): +# return device("cpu") if not jt.has_cuda or not jt.flags.use_cuda else device("cuda") +# Tensor.device = property(device_get) + +# def argmax(x: Var, dim=None, keepdim: bool = False): +# return jt.argmax(x, dim, keepdim)[0] +# Tensor.argmax = argmax + +# def tensor_type(x: Var, dtype=None, 
**kwargs): +# if dtype: +# return x.astype(dtype) +# else: +# return x.dtype +# Tensor.type = tensor_type + +# def is_floating_point(x: Var): +# return "float" in str(x.dtype) +# Tensor.is_floating_point = is_floating_point + +# from . import autograd +# from .autograd import * + +# def tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False): +# if isinstance(data,list): +# data_list = [] +# check = True +# for p in data: +# if isinstance(p, Tensor) and p.numel()==1: +# data_list.append(p.item()) +# elif isinstance(p, (org_int,org_float)): +# data_list.append(p) +# else: +# check = False +# break +# if check: +# data = data_list +# return wrapper(array)(data, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory) + +# # tensor = wrapper(array) +# from_numpy = wrapper(array) +# strided = None + +# def mod_zero_grad(self): +# for p in self.parameters(): +# p.grad = None +# Module.zero_grad = mod_zero_grad + +# class ModuleMisc: +# def parameters(self): +# return iter(super().parameters()) + +# def load_state_dict(self, state_dict, strict=False): +# return super().load_state_dict(state_dict) + +# def to(self, device=None,dtype=None): +# ''' do nothing but return its self''' +# return self +# def register_parameter(self,name,data): +# self.name = data + +# def buffers(self): +# for _, buf in self.named_buffers(): +# yield buf + + +# def make_module(cls): +# class TMod(ModuleMisc, cls): +# def __init__(self, *args, **kw): +# dtype = None +# if "dtype" in kw: +# dtype = kw["dtype"] +# del kw["dtype"] +# self._dtype = dtype +# with jt.flag_scope(th_mode=0): +# if "device" in kw: +# del kw["device"] +# super().__init__(*args, **kw) +# for k,v in self.__dict__.items(): +# if not k.startswith("_") and isinstance(v, Var) \ +# and v.requires_grad: +# v.retain_grad() +# if dtype is not None and isinstance(v, Var): +# v.assign(v.cast(dtype)) +# def __call__(self, *args, **kw): +# args, kw = handle_dtype(args, kw, self._dtype) +# # if forward is override by user, call forward +# if self.__class__.forward is not TMod.forward: +# return self.forward(*args, **kw) +# return self.execute(*args, **kw) +# def forward(self, *args, **kw): +# args, kw = handle_dtype(args, kw, self._dtype) +# return self.execute(*args, **kw) + +# @property +# def training(self): +# if not hasattr(self, "is_train"): +# self.is_train = True +# return self.is_train +# @training.setter +# def training(self, value): +# self.is_train = value + +# TMod.__name__ = cls.__name__ +# return TMod + +# import jtorch.cuda +# import jtorch.nn +# from jtorch.nn import Module, Parameter +# import jtorch.optim + +# from jtorch.utils.dtype import Dtype, get_string_dtype + +# def frombuffer(buffer: bytearray, +# *, +# dtype: Dtype, +# count: int = -1, +# offset: int = 0, +# requires_grad: bool = True) -> Tensor: +# dtype = get_string_dtype(dtype) +# tensor = jt.array(np.frombuffer(buffer, dtype, count=count, offset=offset)) +# if requires_grad and tensor.dtype.is_float(): +# tensor.requires_grad = True +# return tensor + +# def conflict_wrapper(origin_func, new_func): +# def wrapper(*args, **kw): +# if jt.flags.th_mode: +# return new_func(*args, **kw) +# else: +# return origin_func(*args, **kw) +# return wrapper + +# def min(*args, **kw): +# dim = None +# if len(args) >= 2 and isinstance(args[1], org_int): +# dim = args[1] +# elif "dim" in kw and isinstance(kw["dim"], org_int): +# dim = kw["dim"] +# if dim is not None: +# k, v = jt.argmin(*args, **kw) +# return v, k +# elif len(args) == 2 and 
isinstance(args[1], jt.Var): +# return jt.minimum(args[0], args[1]) +# else: +# return jt.min(*args, **kw) +# Tensor.min = conflict_wrapper(jt.min, min) + +# def max(*args, **kw): +# dim = None +# if "dim" in kw: +# x = kw["dim"] +# if len(args) >= 2 and isinstance(args[1], org_int): +# dim = args[1] +# elif "dim" in kw and isinstance(kw["dim"], org_int): +# dim = kw["dim"] +# if dim is not None: +# k, v = jt.argmax(*args, **kw) +# return v, k +# elif len(args) == 2 and isinstance(args[1], jt.Var): +# return jt.maximum(args[0], args[1]) +# else: +# return jt.max(*args, **kw) +# Tensor.max = conflict_wrapper(jt.max, max) + +# def argsort(*args, **kw): +# k, v = jt.argsort(*args, **kw) +# return k +# Tensor.argsort = conflict_wrapper(jt.argsort, argsort) + +# LongTensor = jt.int64 +# FloatTensor = jt.float +# HalfTensor = jt.float16 +# BoolTensor = jt.bool +# IntTensor = jt.int32 + +# class JDType: +# def __init__(self, func, str): +# self.func = func +# self.str = str +# self.__name__ = str.split(".")[-1] +# def __call__(self, *args, **kw): +# return self.func(*args, **kw) +# def __str__(self): +# return self.str +# def is_floating_point(self): +# return "float" in str(self.str) + +# int8 = JDType(jt.int8, "torch.int8") +# int16 = JDType(jt.int16, "torch.int16") +# int = int32 = JDType(jt.int32, "torch.int32") +# long = int64 = JDType(jt.int64, "torch.int64") + +# half = float16 = JDType(jt.float16, "torch.float16") +# float = float32 = JDType(jt.float32, "torch.float32") +# double = float64 = JDType(jt.float64, "torch.float64") +# bfloat16 = "bfloat16" # TODO +# complex64 = "complex64" # TODO +# complex128 = "complex128" # TODO +# def get_JDtype(dtype): +# if dtype=='float32' or dtype == jt.float32: +# return float32 +# elif dtype=='float64' or dtype == jt.float64: +# return float64 +# elif dtype=='float16' or dtype == jt.float16: +# return float16 +# elif dtype=='int32' or dtype == jt.int32: +# return int32 +# elif dtype=='int64' or dtype == jt.int64: +# return int64 +# elif dtype=='int16' or dtype == jt.int16: +# return int16 +# elif dtype=='int8' or dtype == jt.int8: +# return int8 +# else: +# raise Exception("dtype {} not supported".format(dtype)) + +# def load(path,**kwargs): +# def _to_jittor(data): +# if isinstance(data,dict): +# return {k:_to_jittor(d) for k,d in data.items()} +# if isinstance(data,list): +# return [_to_jittor(d) for d in data] +# if isinstance(data,np.ndarray): +# return jt.array(data) +# return data +# data = jt.load(path) + +# return _to_jittor(data) + +# def is_tensor(x): +# return isinstance(x, Tensor) + +# manual_seed = jt.set_global_seed +# jt.flags.amp_level = 3 +# Size = jt.NanoVector + +# class Generator: +# def __init__(self,*args,**kw) -> None: +# self.seed = None +# def manual_seed(self,seed): +# self.seed = seed + + + +# from . 
import fx + + +# _default_type = "float32" + +# def get_default_dtype(): +# return _default_type +# def set_default_dtype(dtype): +# global _default_type +# _default_type = dtype + +# dtype = JDType + +# def div(x,y,rounding_mode="floor"): +# assert rounding_mode == "floor" +# z = (x / y) +# if rounding_mode == "floor": +# z = z.floor() +# if x.dtype == "int32" and (isinstance(y,org_int) or y.dtype == "int32"): +# z = z.int32() +# return z + + +# def randn(*args,**kw): +# wrap_randn = wrapper(jt.randn) +# generator = kw.get('generator',None) +# kw.pop('generator',None) +# if 'layout' in kw: +# del kw['layout'] +# if generator is not None and generator.seed is not None: +# jt.set_global_seed(generator.seed) +# return wrap_randn(*args,**kw) + +# def rand(*args,**kw): +# print("rand") +# wrap_rand = wrapper(jt.rand) +# generator = kw.get('generator',None) +# kw.pop('generator',None) +# if 'layout' in kw: +# del kw['layout'] +# if generator is not None and generator.seed is not None: +# jt.set_global_seed(generator.seed) +# return wrap_rand(*args,**kw) + + + +# def set_default_tensor_type(t: type or str): +# if isinstance(t, str): +# info = t.split(".") +# if len(info) == 3 and info[1] == 'cuda': +# jt.flags.use_cuda = 1 +# #TODO: type + + +# def clamp(x, min=None, max=None): +# return jt.clamp(x, min, max) + + +# def to(x,*args,**kw): +# device = None +# if len(args) == 1: +# device = args[0] +# if isinstance(device, jt.NanoString) or callable(device): +# return jt.to(x,*args,**kw) +# if 'cpu' in str(device): +# args = [] +# device = kw.get("device",None) +# if 'cpu' in str(device): +# kw.pop('device',None) +# print("to cpu") +# # print(kw) +# return jt.to(x,*args,**kw) +# Tensor.to = conflict_wrapper(jt.to, to) + +# mm = wrapper(jt.matmul) + +# def _data_get(x): +# return x + +# def _data_set(x, value): +# x.assign(value) + +# Tensor.data = property(_data_get, _data_set) +# Tensor.layout = None \ No newline at end of file diff --git a/python/jittor/compatibility/autograd.py b/python/jittor/compatibility/autograd.py new file mode 100644 index 00000000..5ed88dde --- /dev/null +++ b/python/jittor/compatibility/autograd.py @@ -0,0 +1,134 @@ +import jittor as jt +from jittor import Var +from collections.abc import Sequence, Mapping + +Variable = Var + +class FunctionContext: + def save_for_backward(self, *args): + self.saved_tensors = args + +class Function: + ''' Function Module for customized backward operations + +Example 1 (Function can have multiple input and multiple output, and user +can store value for backward computation):: + + import jtorch + from jtorch import Function + + class MyFunc(Function): + @staticmethod + def forward(self, x, y): + self.x = x + self.y = y + return x*y, x/y + + @staticmethod + def backward(self, grad0, grad1): + return grad0 * self.y, grad1 * self.x + + a = jtorch.array(3.0) + a.requires_grad = True + b = jtorch.array(4.0) + b.requires_grad = True + func = MyFunc.apply + c,d = func(a, b) + (c+d*3).backward() + assert a.grad.data == 4 + assert b.grad.data == 9 + +Example 2(Function can return None for no gradiant, and gradiant +can also be None):: + + import jtorch + from jtorch import Function + + class MyFunc(Function): + @staticmethod + def forward(self, x, y): + self.x = x + self.y = y + return x*y, x/y + + @staticmethod + def backward(self, grad0, grad1): + assert grad1 is None + return grad0 * self.y, None + a = jt.array(3.0) + a.requires_grad = True + b = jt.array(4.0) + b.requires_grad = True + func = MyFunc.apply + c,d = func(a, b) + d.stop_grad() + 
da, db = jt.grad(c+d*3, [a, b]) + assert da.data == 4 + assert db.data == 0 + + ''' + def __call__(self, *args): + backup = args + args = list(args) + taped_inputs = [] + taped_outputs = [] + input_mask = [-1] * len(args) + for i,v in enumerate(args): + if isinstance(v, Var): + if v.is_stop_grad(): + # -2 in input_mask represents it is stop_grad + input_mask[i] = -2 + continue + v = v.tape() + input_mask[i] = len(taped_inputs) + args[i] = v + taped_inputs.append(v) + ctx = FunctionContext() + ori_res = self.forward(ctx, *args) + # ori_res = self.execute(*args) + if not isinstance(ori_res, Sequence): + res = [ori_res] + else: + res = list(ori_res) + output_mask = [-1] * len(res) + for i,v in enumerate(res): + if isinstance(v, Var): + v = v.tape() + output_mask[i] = len(taped_outputs) + res[i] = v + taped_outputs.append(v) + ctx.input_mask = input_mask + ctx.output_mask = output_mask + # tape output and input together so + # backward treat them as one operator + jt.tape_together(taped_inputs, taped_outputs, + lambda *args: self._grad(ctx, self, *args)) + if isinstance(ori_res, Sequence): + return res + else: + return res[0] + + @staticmethod + def _grad(ctx, func, *args): + new_args = ( (args[i] if i>=0 else None) for i in ctx.output_mask ) + ret = func.backward(ctx, *new_args) + if not isinstance(ret, Sequence): + ret = (ret,) + new_ret = [] + for i, r in enumerate(ret): + j = ctx.input_mask[i] + if j<0: + # -2 in input_mask represents it is stop_grad + assert r is None or j==-2, f"{type(self)}'s {i}-th returned grad should be None, "\ + "because the input value is not jittor variable." + else: + new_ret.append(r) + return new_ret + + def dfs(self, parents, k, callback, callback_leave=None): + pass + + @classmethod + def apply(cls, *args, **kw): + func = cls() + return func(*args, **kw) diff --git a/python/jittor/compatibility/compiler.py b/python/jittor/compatibility/compiler.py new file mode 100644 index 00000000..77bab138 --- /dev/null +++ b/python/jittor/compatibility/compiler.py @@ -0,0 +1,39 @@ +import jittor as jt +import jittor_utils +import glob +import os +from jittor import pyjt_compiler +import sys +from jittor_utils import lock + + +jtorch_path = os.path.dirname(__file__) +cache_path = os.path.join(jt.compiler.cache_path, "jtorch") +# os.makedirs(cache_path, exist_ok=True) +os.makedirs(os.path.join(cache_path, "gen"), exist_ok=True) + +with lock.lock_scope(): + pyjt_gen_src = pyjt_compiler.compile(cache_path, jtorch_path) + +ext_args = 'c[cu]' if jt.has_cuda else 'cc' +files = glob.glob(jtorch_path+"/src/**/*."+ext_args, recursive=True) +files += pyjt_gen_src +cc_flags = " -I\""+os.path.join(jtorch_path, "src")+"\" " +if os.environ.get("use_data_o", "1") == "1": + files += glob.glob(jtorch_path+"/src/**/*.o", recursive=True) + files = [f for f in files if "__data__" not in f] + + +with lock.lock_scope(): + jt.compiler.compile( + jt.compiler.cc_path, + jt.compiler.cc_flags+jt.compiler.opt_flags+ cc_flags, + files, + "jtorch_core"+jt.compiler.extension_suffix, + obj_dirname="jtorch_objs") + + +with jittor_utils.import_scope(jt.compiler.import_flags): + import jtorch_core as core + +jt.flags.th_mode = 1 diff --git a/python/jittor/compatibility/cuda.py b/python/jittor/compatibility/cuda.py new file mode 100644 index 00000000..75665c7c --- /dev/null +++ b/python/jittor/compatibility/cuda.py @@ -0,0 +1,64 @@ +import jittor as jt +import jtorch + +def is_available(): + return jt.has_cuda + +def device_count(): + return int(jt.has_cuda) + +def set_device(device=None): + pass + +def 
get_rng_state(device=None): + pass + +def current_device(): + return jtorch.device("cuda") + +def mem_get_info(i): + return ("75GB",) + + +class Generator: + def __init__(self): + pass + + def set_state(self, state): + self.state = state + +default_generators = [Generator()] +_lazy_call = lambda func: func() +device = None + +LongTensor = jt.int64 +FloatTensor = jt.float +HalfTensor = jt.float16 +BoolTensor = jt.bool + +manual_seed = jt.set_global_seed +manual_seed_all = jt.set_global_seed + +def synchronize(): + jt.sync_all(True) + +class Event: + pass + +class Stream: + pass + +from typing import Any + +from .gradscaler import GradScaler + +class autocast: + def __init__(self,**kwargs): + pass + + def __enter__(self,): + pass + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): + pass + diff --git a/python/jittor/compatibility/distributed.py b/python/jittor/compatibility/distributed.py new file mode 100644 index 00000000..e39f559a --- /dev/null +++ b/python/jittor/compatibility/distributed.py @@ -0,0 +1,53 @@ +import datetime +from enum import Enum +import jittor as jt + + +class DistributedDataParallel: + def __new__(cls, model): + return model + +def is_initialized(): + return True + +def get_rank(group=None): + return 0 + +def get_world_size(group=None): + return 1 + +def get_backend(group=None): + return "nccl" + +def new_group(ranks=None, timeout=datetime.timedelta(seconds=1800), backend=None, pg_options=None): + return 1 + +def barrier(): + pass + +def is_available(): + return True + +def is_built(): + return True + +class ReduceOp: + SUM = 0 + +class GroupMember: + WORLD = 0 + +class ProcessGroup: + pass + +class Join: + pass + +dist_backend = Enum("dist_backend", ("GLOO", "MPI", "NCCL")) +_backend = dist_backend.NCCL + +def is_mpi_available(): + return jt.in_mpi + +def DistributedDataParallel(model, *args, **kw): + return model diff --git a/python/jittor/compatibility/distributions.py b/python/jittor/compatibility/distributions.py new file mode 100644 index 00000000..a98dfe29 --- /dev/null +++ b/python/jittor/compatibility/distributions.py @@ -0,0 +1,15 @@ +import jittor as jt + +class RelaxedBernoulli: + def __init__(self, temperature, probs=None, logits=None): + self.temperature = temperature + self.probs = probs + self.logits = logits + + def rsample(self): + noise = jt.rand_like(self.logits) + eps = 1e-20 + noise = jt.clamp(noise, eps, 1.0 - eps) + logit_noise = jt.log(noise) - jt.log(1 - noise) + sample = (self.logits + logit_noise) / self.temperature + return jt.sigmoid(sample) diff --git a/python/jittor/compatibility/fft/__init__.py b/python/jittor/compatibility/fft/__init__.py new file mode 100644 index 00000000..7a89fc9c --- /dev/null +++ b/python/jittor/compatibility/fft/__init__.py @@ -0,0 +1,5 @@ +#TODO: Implement FFT and IFFT +fftn = None +fftshift = None +ifftn = None +ifftshift = None \ No newline at end of file diff --git a/python/jittor/compatibility/fx.py b/python/jittor/compatibility/fx.py new file mode 100644 index 00000000..0f0eb4f8 --- /dev/null +++ b/python/jittor/compatibility/fx.py @@ -0,0 +1,2 @@ +class Proxy: + pass \ No newline at end of file diff --git a/python/jittor/compatibility/gradscaler.py b/python/jittor/compatibility/gradscaler.py new file mode 100644 index 00000000..087d6bb2 --- /dev/null +++ b/python/jittor/compatibility/gradscaler.py @@ -0,0 +1,519 @@ +from collections import defaultdict, abc +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, cast +import inspect +import warnings + +import jittor as jt 
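+# Adapted from PyTorch's torch.cuda.amp.GradScaler. In this Jittor port the
+# scale and growth tracker are initialized as plain Python numbers (see
+# _lazy_init_scale_growth_tracker below) rather than device tensors, and
+# jt.Var stands in for torch.Tensor throughout.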
+# import torch + +def _refresh_per_optimizer_state(): + return {} + + +class GradScaler: + _scale: Optional[jt.Var] + _grows_tracker: Optional[jt.Var] + _per_optimizer_states: Dict[int, Dict[str, Any]] + """ + An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling + conveniently. + + * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. + * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. + * ``scaler.update()`` updates ``scaler``'s scale factor. + + Example:: + + # Creates a GradScaler once at the beginning of training. + scaler = GradScaler() + + for epoch in epochs: + for input, target in data: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + + # Scales loss. Calls backward() on scaled loss to create scaled gradients. + scaler.scale(loss).backward() + + # scaler.step() first unscales gradients of the optimizer's params. + # If gradients don't contain infs/NaNs, optimizer.step() is then called, + # otherwise, optimizer.step() is skipped. + scaler.step(optimizer) + + # Updates the scale for next iteration. + scaler.update() + + See the :ref:`Automatic Mixed Precision examples` for usage + (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, + and multiple losses/optimizers. + + ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow, + a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if + the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used + without incurring inf or NaN gradient values. + ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every + ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`). + + * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params + themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``. + + * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual. + If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by + ``growth_factor``. + + The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its + value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these + iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). + + Args: + init_scale (float, optional, default=2.**16): Initial scale factor. + growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during + :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. + backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during + :meth:`update` if inf/NaN gradients occur in an iteration. + growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients + that must occur for the scale to be multiplied by ``growth_factor``. + enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply + invokes the underlying ``optimizer.step()``, and other methods become no-ops. 
+ Default: ``True`` + """ + def __init__(self, + init_scale=2.**16, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, + enabled=True): + self._enabled = enabled + + if self._enabled: + assert growth_factor > 1.0, "The growth factor must be > 1.0." + assert backoff_factor < 1.0, "The backoff factor must be < 1.0." + + self._init_scale = init_scale + # self._scale will be lazily initialized during the first call to scale() + self._scale = None + self._growth_factor = growth_factor + self._backoff_factor = backoff_factor + self._growth_interval = growth_interval + self._init_growth_tracker = 0 + # self._growth_tracker will be lazily initialized during the first call to scale() + self._growth_tracker = None + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _check_scale_growth_tracker(self, funcname) -> Tuple[jt.Var, jt.Var]: + fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." + assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix + assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix + return (self._scale, self._growth_tracker) + + def _lazy_init_scale_growth_tracker(self): + assert self._growth_tracker is None, "_growth_tracker initialized before _scale" + self._scale = self._init_scale + self._growth_tracker = self._init_growth_tracker + + def scale(self, outputs): + """ + Multiplies ('scales') a tensor or list of tensors by the scale factor. + + Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned + unmodified. + + Args: + outputs (Tensor or iterable of Tensors): Outputs to scale. + """ + if not self._enabled: + return outputs + + + # Short-circuit for the common case. + if isinstance(outputs, jt.Var): + assert jt.flags.use_cuda == 1 + if self._scale is None: + self._lazy_init_scale_growth_tracker() + assert self._scale is not None + return outputs * self._scale + + def apply_scale(val): + if isinstance(val, jt.Var): + assert jt.flags.use_cuda == 1 + if self._scale is None: + self._lazy_init_scale_growth_tracker() + assert self._scale is not None + return val * self._scale + elif isinstance(val, abc.Iterable): + iterable = map(apply_scale, val) + if isinstance(val, (list, tuple)): + return type(val)(iterable) + else: + return iterable + else: + raise ValueError("outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) + + def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): + with jt.no_grad(): + optimizer.pre_step() + for group in optimizer.param_groups: + for to_unscale in group["grads"]: + if to_unscale is None or isinstance(to_unscale,(int,float)): + continue + if (not allow_fp16) and str(to_unscale.dtype) == "float16": + raise ValueError("Attempting to unscale FP16 gradients.") + + if not (to_unscale.isinf().any()): + if inv_scale != 1.0: + to_unscale.update(to_unscale*inv_scale) + else: + found_inf = 1.0 + + return found_inf + + def unscale_(self, optimizer): + """ + Divides ("unscales") the optimizer's gradient tensors by the scale factor. + + :meth:`unscale_` is optional, serving cases where you need to + :ref:`modify or inspect gradients` + between the backward pass(es) and :meth:`step`. + If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. + + Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: + + ... 
+ scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + scaler.step(optimizer) + scaler.update() + + Args: + optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. + + .. note:: + :meth:`unscale_` does not incur a CPU-GPU sync. + + .. warning:: + :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, + and only after all gradients for that optimizer's assigned parameters have been accumulated. + Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. + + .. warning:: + :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. + """ + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if hasattr(optimizer,"get_find_inf"): + return + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + assert self._scale is not None + inv_scale = 1.0 / self._scale + found_inf = 0.0 + optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False) + + + def step(self, optimizer, *args, **kwargs): + """ + :meth:`step` carries out the following two operations: + + 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer`` + earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs. + 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled + gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params. + + ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``. + + Returns the return value of ``optimizer.step(*args, **kwargs)``. + + Args: + optimizer (torch.optim.Optimizer): Optimizer that applies the gradients. + args: Any arguments. + kwargs: Any keyword arguments. + + .. warning:: + Closure use is not currently supported. + """ + if (not self._enabled): + return optimizer.step(*args, **kwargs) + + if "closure" in kwargs: + raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.") + + self._check_scale_growth_tracker("step") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + retval = None + + if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling): + # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. + # The contract with custom optimizers is that their step() should accept an additional, + # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: + # it can query its own state, invoke unscale_ on itself, etc + # The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument + # to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale` + # and `found_inf` to the passed optimizer so that the optimizer can utilize those + # to skip the parameter updates or unscale gradients before updating parameters in + # the fused kernel, e.g. `FusedAdamMathFunctor`. + # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`, + # while the method is expected to be called by users side, i.e. their optimizers. 
+ kwargs_ = kwargs + has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters + if has_grad_scaler_kwarg: + warnings.warn( + "GradScaler is going to stop passing itself as a keyword argument to the passed " + "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and " + "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.", + FutureWarning) + kwargs_.update({"grad_scaler": self}) + else: + if optimizer_state["stage"] is OptState.READY: + self._check_inf_per_device(optimizer) + scaler = self._get_scale_async() + found_inf = cast( + jt.Var, + sum([ + t for t in optimizer_state["found_inf_per_device"].values() + ]) + ) + optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler + optimizer.found_inf = found_inf + retval = optimizer.step(*args, **kwargs_) + optimizer_state["stage"] = OptState.STEPPED + if not has_grad_scaler_kwarg: + del optimizer.grad_scale + del optimizer.found_inf + return retval + + if hasattr(optimizer,"get_find_inf"): + optimizer.set_grad_scale(self._scale) + optimizer.step() + optimizer_state["found_inf_per_device"] = optimizer.get_find_inf() + return + + retval = None + if not optimizer_state["found_inf_per_device"]: + retval = optimizer.step(*args, **kwargs) + else: + optimizer.post_step() + + return retval + + + def update(self, new_scale=None): + """ + Updates the scale factor. + + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + + Args: + new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor. + + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + """ + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) # type: ignore[union-attr] + else: + reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." + assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined] + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale.copy_(new_scale) # type: ignore[union-attr] + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [state["found_inf_per_device"] + for state in self._per_optimizer_states.values() + ] + + assert len(found_infs) > 0, "No inf checks were recorded prior to update." 
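+ # Sum the per-optimizer inf/NaN flags collected this iteration: any non-zero
+ # value means an overflow was seen, so the scale is backed off below;
+ # otherwise the growth tracker advances toward growth_interval.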
+ + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + + current_scale = _scale + if found_inf_combined: + current_scale *=self._backoff_factor + _growth_tracker = 0 + else: + successful = _growth_tracker+1 + if successful == self._growth_interval: + new_scale = current_scale*self._growth_factor + if new_scale < 1e9: + current_scale = new_scale + _growth_tracker = 0 + else: + _growth_tracker = successful + + self._scale, self._growth_tracker = current_scale,_growth_tracker + + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _get_scale_async(self): + return self._scale + + def get_scale(self): + """ + Returns a Python float containing the current scale, or 1.0 if scaling is disabled. + + .. warning:: + :meth:`get_scale` incurs a CPU-GPU sync. + """ + if self._enabled: + return self._init_scale if self._scale is None else self._get_scale_async() + else: + return 1.0 + + def get_growth_factor(self): + r""" + Returns a Python float containing the scale growth factor. + """ + return self._growth_factor + + def set_growth_factor(self, new_factor): + r""" + Args: + new_scale (float): Value to use as the new scale growth factor. + """ + self._growth_factor = new_factor + + def get_backoff_factor(self): + r""" + Returns a Python float containing the scale backoff factor. + """ + return self._backoff_factor + + def set_backoff_factor(self, new_factor): + r""" + Args: + new_scale (float): Value to use as the new scale backoff factor. + """ + self._backoff_factor = new_factor + + def get_growth_interval(self): + r""" + Returns a Python int containing the growth interval. + """ + return self._growth_interval + + def set_growth_interval(self, new_interval): + r""" + Args: + new_interval (int): Value to use as the new growth interval. + """ + self._growth_interval = new_interval + + def _get_growth_tracker(self): + if self._enabled: + return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item() + else: + return 0 + + def is_enabled(self): + r""" + Returns a bool indicating whether this instance is enabled. + """ + return self._enabled + + def state_dict(self): + r""" + Returns the state of the scaler as a :class:`dict`. It contains five entries: + + * ``"scale"`` - a Python float containing the current scale + * ``"growth_factor"`` - a Python float containing the current growth factor + * ``"backoff_factor"`` - a Python float containing the current backoff factor + * ``"growth_interval"`` - a Python int containing the current growth interval + * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. + + If this instance is not enabled, returns an empty dict. + + .. note:: + If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` + should be called after :meth:`update`. + """ + return {"scale": self.get_scale(), + "growth_factor": self._growth_factor, + "backoff_factor": self._backoff_factor, + "growth_interval": self._growth_interval, + "_growth_tracker": self._get_growth_tracker()} if self._enabled else {} + + def load_state_dict(self, state_dict): + r""" + Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op. + + Args: + state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. 
+ """ + if not self._enabled: + return + + if len(state_dict) == 0: + raise RuntimeError("The source state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler.") + + self._init_scale = state_dict["scale"] + if self._scale is not None: + self._scale.fill_(state_dict["scale"]) + self._growth_factor = state_dict["growth_factor"] + self._backoff_factor = state_dict["backoff_factor"] + self._growth_interval = state_dict["growth_interval"] + self._init_growth_tracker = state_dict["_growth_tracker"] + if self._growth_tracker is not None: + self._growth_tracker.fill_(state_dict["_growth_tracker"]) + + def __getstate__(self): + state = self.__dict__.copy() + if self._enabled: + assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ + "of an iteration, or at the end after scaler.update()." + # Pickling _scale and _growth_tracker Tensors directly triggers + # "warnings.warn("pickle support for Storage will be removed in 1.5..." + # so instead, we set the unpickled instance up to reinitialize them lazily. + state['_init_scale'] = self.get_scale() + state['_init_growth_tracker'] = self._get_growth_tracker() + state['_scale'] = None + state['_growth_tracker'] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + def _check_inf_per_device(self, optimizer): + _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") + + dummy_inv_scale = 1.0 + found_inf = 0.0 + + self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ + self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) + + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] + + def _found_inf_per_device(self, optimizer): + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/python/jittor/compatibility/gradscaler_old.py b/python/jittor/compatibility/gradscaler_old.py new file mode 100644 index 00000000..389be2cf --- /dev/null +++ b/python/jittor/compatibility/gradscaler_old.py @@ -0,0 +1,556 @@ +from collections import defaultdict, abc +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, cast +import inspect +import warnings + +import jittor as jt +# import torch + + +__all__ = ["OptState", "GradScaler"] + + +# Defines default_factory for GradScaler's _per_optimizer_states defaultdict, +# as well as associated "enum" values. Prefers defining these at top level because +# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory. +# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler +# causes a circular reference, which we'd rather avoid. +class OptState(Enum): + READY = 0 + UNSCALED = 1 + STEPPED = 2 + + +def _refresh_per_optimizer_state(): + return {"stage": OptState.READY, "found_inf_per_device": {}} + + +class GradScaler: + _scale: Optional[jt.Var] + _grows_tracker: Optional[jt.Var] + _per_optimizer_states: Dict[int, Dict[str, Any]] + """ + An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling + conveniently. + + * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. + * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. + * ``scaler.update()`` updates ``scaler``'s scale factor. + + Example:: + + # Creates a GradScaler once at the beginning of training. 
+ scaler = GradScaler() + + for epoch in epochs: + for input, target in data: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + + # Scales loss. Calls backward() on scaled loss to create scaled gradients. + scaler.scale(loss).backward() + + # scaler.step() first unscales gradients of the optimizer's params. + # If gradients don't contain infs/NaNs, optimizer.step() is then called, + # otherwise, optimizer.step() is skipped. + scaler.step(optimizer) + + # Updates the scale for next iteration. + scaler.update() + + See the :ref:`Automatic Mixed Precision examples` for usage + (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, + and multiple losses/optimizers. + + ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow, + a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if + the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used + without incurring inf or NaN gradient values. + ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every + ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`). + + * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params + themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``. + + * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual. + If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by + ``growth_factor``. + + The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its + value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these + iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). + + Args: + init_scale (float, optional, default=2.**16): Initial scale factor. + growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during + :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. + backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during + :meth:`update` if inf/NaN gradients occur in an iteration. + growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients + that must occur for the scale to be multiplied by ``growth_factor``. + enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply + invokes the underlying ``optimizer.step()``, and other methods become no-ops. + Default: ``True`` + """ + def __init__(self, + init_scale=2.**16, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, + enabled=True): + self._enabled = enabled + + if self._enabled: + assert growth_factor > 1.0, "The growth factor must be > 1.0." + assert backoff_factor < 1.0, "The backoff factor must be < 1.0." 
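+ # growth_factor must enlarge the scale and backoff_factor must shrink it;
+ # values at or across 1.0 would stop the scale from ever adapting in update().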
+ + self._init_scale = init_scale + # self._scale will be lazily initialized during the first call to scale() + self._scale = None + self._growth_factor = growth_factor + self._backoff_factor = backoff_factor + self._growth_interval = growth_interval + self._init_growth_tracker = 0 + # self._growth_tracker will be lazily initialized during the first call to scale() + self._growth_tracker = None + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _check_scale_growth_tracker(self, funcname) -> Tuple[jt.Var, jt.Var]: + fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." + assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix + assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix + return (self._scale, self._growth_tracker) + + def _lazy_init_scale_growth_tracker(self): + assert self._growth_tracker is None, "_growth_tracker initialized before _scale" + self._scale = self._init_scale + self._growth_tracker = self._init_growth_tracker + + def scale(self, outputs): + """ + Multiplies ('scales') a tensor or list of tensors by the scale factor. + + Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned + unmodified. + + Args: + outputs (Tensor or iterable of Tensors): Outputs to scale. + """ + print("scale") + if not self._enabled: + return outputs + + + # Short-circuit for the common case. + if isinstance(outputs, jt.Var): + assert jt.flags.use_cuda == 1 + if self._scale is None: + self._lazy_init_scale_growth_tracker() + assert self._scale is not None + return outputs * self._scale + + def apply_scale(val): + if isinstance(val, jt.Var): + assert jt.flags.use_cuda == 1 + if self._scale is None: + self._lazy_init_scale_growth_tracker() + assert self._scale is not None + return val * self._scale + elif isinstance(val, abc.Iterable): + iterable = map(apply_scale, val) + if isinstance(val, (list, tuple)): + return type(val)(iterable) + else: + return iterable + else: + raise ValueError("outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) + + def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be hundreds of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + with jt.no_grad(): + optimizer.pre_step() + for group in optimizer.param_groups: + for to_unscale in group["grads"]: + if to_unscale is None or isinstance(to_unscale,(int,float)): + continue + if (not allow_fp16) and str(to_unscale.dtype) == "float16": + raise ValueError("Attempting to unscale FP16 gradients.") + + if not (to_unscale.isinf().any()): + if inv_scale != 1.0: + to_unscale.update(to_unscale*inv_scale) + else: + found_inf = 1.0 + + return found_inf + + def unscale_(self, optimizer): + """ + Divides ("unscales") the optimizer's gradient tensors by the scale factor. + + :meth:`unscale_` is optional, serving cases where you need to + :ref:`modify or inspect gradients` + between the backward pass(es) and :meth:`step`. + If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. 
+ + Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: + + ... + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + scaler.step(optimizer) + scaler.update() + + Args: + optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. + + .. note:: + :meth:`unscale_` does not incur a CPU-GPU sync. + + .. warning:: + :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, + and only after all gradients for that optimizer's assigned parameters have been accumulated. + Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. + + .. warning:: + :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. + """ + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.UNSCALED: + raise RuntimeError("unscale_() has already been called on this optimizer since the last update().") + elif optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + assert self._scale is not None + inv_scale = 1.0 / self._scale + found_inf = 0.0 + optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False) + optimizer_state["stage"] = OptState.UNSCALED + + def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs): + retval = None + if not optimizer_state["found_inf_per_device"]: + retval = optimizer.step(*args, **kwargs) + else: + optimizer.post_step() + + return retval + + def step(self, optimizer, *args, **kwargs): + """ + :meth:`step` carries out the following two operations: + + 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer`` + earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs. + 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled + gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params. + + ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``. + + Returns the return value of ``optimizer.step(*args, **kwargs)``. + + Args: + optimizer (torch.optim.Optimizer): Optimizer that applies the gradients. + args: Any arguments. + kwargs: Any keyword arguments. + + .. warning:: + Closure use is not currently supported. + """ + if (not self._enabled): + return optimizer.step(*args, **kwargs) + + if "closure" in kwargs: + raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.") + + self._check_scale_growth_tracker("step") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("step() has already been called since the last update().") + + retval = None + + if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling): + # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. + # The contract with custom optimizers is that their step() should accept an additional, + # optional grad_scaler kwarg. 
We append self to the kwargs so the custom optimizer has full information: + # it can query its own state, invoke unscale_ on itself, etc + # The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument + # to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale` + # and `found_inf` to the passed optimizer so that the optimizer can utilize those + # to skip the parameter updates or unscale gradients before updating parameters in + # the fused kernel, e.g. `FusedAdamMathFunctor`. + # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`, + # while the method is expected to be called by users side, i.e. their optimizers. + kwargs_ = kwargs + has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters + if has_grad_scaler_kwarg: + warnings.warn( + "GradScaler is going to stop passing itself as a keyword argument to the passed " + "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and " + "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.", + FutureWarning) + kwargs_.update({"grad_scaler": self}) + else: + if optimizer_state["stage"] is OptState.READY: + self._check_inf_per_device(optimizer) + scaler = self._get_scale_async() + found_inf = cast( + jt.Var, + sum([ + t for t in optimizer_state["found_inf_per_device"].values() + ]) + ) + optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler + optimizer.found_inf = found_inf + retval = optimizer.step(*args, **kwargs_) + optimizer_state["stage"] = OptState.STEPPED + if not has_grad_scaler_kwarg: + del optimizer.grad_scale + del optimizer.found_inf + return retval + + + if optimizer_state["stage"] is OptState.READY: + self.unscale_(optimizer) + + assert "found_inf_per_device" in optimizer_state, "No inf checks were recorded for this optimizer." + + retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs) + + optimizer_state["stage"] = OptState.STEPPED + + return retval + + def update(self, new_scale=None): + """ + Updates the scale factor. + + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + + Args: + new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor. + + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + """ + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) # type: ignore[union-attr] + else: + reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." 
+ assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined] + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale.copy_(new_scale) # type: ignore[union-attr] + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [state["found_inf_per_device"] + for state in self._per_optimizer_states.values() + ] + + assert len(found_infs) > 0, "No inf checks were recorded prior to update." + + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + + current_scale = _scale + if found_inf_combined: + current_scale *=self._backoff_factor + _growth_tracker = 0 + else: + successful = _growth_tracker+1 + if successful == self._growth_interval: + new_scale = current_scale*self._growth_factor + if new_scale < 1e9: + current_scale = new_scale + _growth_tracker = 0 + else: + _growth_tracker = successful + + self._scale, self._growth_tracker = current_scale,_growth_tracker + + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _get_scale_async(self): + return self._scale + + def get_scale(self): + """ + Returns a Python float containing the current scale, or 1.0 if scaling is disabled. + + .. warning:: + :meth:`get_scale` incurs a CPU-GPU sync. + """ + if self._enabled: + return self._init_scale if self._scale is None else self._get_scale_async() + else: + return 1.0 + + def get_growth_factor(self): + r""" + Returns a Python float containing the scale growth factor. + """ + return self._growth_factor + + def set_growth_factor(self, new_factor): + r""" + Args: + new_scale (float): Value to use as the new scale growth factor. + """ + self._growth_factor = new_factor + + def get_backoff_factor(self): + r""" + Returns a Python float containing the scale backoff factor. + """ + return self._backoff_factor + + def set_backoff_factor(self, new_factor): + r""" + Args: + new_scale (float): Value to use as the new scale backoff factor. + """ + self._backoff_factor = new_factor + + def get_growth_interval(self): + r""" + Returns a Python int containing the growth interval. + """ + return self._growth_interval + + def set_growth_interval(self, new_interval): + r""" + Args: + new_interval (int): Value to use as the new growth interval. + """ + self._growth_interval = new_interval + + def _get_growth_tracker(self): + if self._enabled: + return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item() + else: + return 0 + + def is_enabled(self): + r""" + Returns a bool indicating whether this instance is enabled. + """ + return self._enabled + + def state_dict(self): + r""" + Returns the state of the scaler as a :class:`dict`. It contains five entries: + + * ``"scale"`` - a Python float containing the current scale + * ``"growth_factor"`` - a Python float containing the current growth factor + * ``"backoff_factor"`` - a Python float containing the current backoff factor + * ``"growth_interval"`` - a Python int containing the current growth interval + * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. + + If this instance is not enabled, returns an empty dict. + + .. 
note:: + If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` + should be called after :meth:`update`. + """ + return {"scale": self.get_scale(), + "growth_factor": self._growth_factor, + "backoff_factor": self._backoff_factor, + "growth_interval": self._growth_interval, + "_growth_tracker": self._get_growth_tracker()} if self._enabled else {} + + def load_state_dict(self, state_dict): + r""" + Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op. + + Args: + state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. + """ + if not self._enabled: + return + + if len(state_dict) == 0: + raise RuntimeError("The source state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler.") + + self._init_scale = state_dict["scale"] + if self._scale is not None: + self._scale.fill_(state_dict["scale"]) + self._growth_factor = state_dict["growth_factor"] + self._backoff_factor = state_dict["backoff_factor"] + self._growth_interval = state_dict["growth_interval"] + self._init_growth_tracker = state_dict["_growth_tracker"] + if self._growth_tracker is not None: + self._growth_tracker.fill_(state_dict["_growth_tracker"]) + + def __getstate__(self): + state = self.__dict__.copy() + if self._enabled: + assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ + "of an iteration, or at the end after scaler.update()." + # Pickling _scale and _growth_tracker Tensors directly triggers + # "warnings.warn("pickle support for Storage will be removed in 1.5..." + # so instead, we set the unpickled instance up to reinitialize them lazily. + state['_init_scale'] = self.get_scale() + state['_init_growth_tracker'] = self._get_growth_tracker() + state['_scale'] = None + state['_growth_tracker'] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + def _check_inf_per_device(self, optimizer): + _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") + + dummy_inv_scale = 1.0 + found_inf = 0.0 + + self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ + self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) + + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] + + def _found_inf_per_device(self, optimizer): + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/python/jittor/compatibility/misc.py b/python/jittor/compatibility/misc.py new file mode 100644 index 00000000..8e9ed20d --- /dev/null +++ b/python/jittor/compatibility/misc.py @@ -0,0 +1,12 @@ +import math + +def _jit_set_profiling_mode(x): pass +def _jit_set_profiling_executor(x): pass +def _jit_override_can_fuse_on_cpu(x): pass +def _jit_override_can_fuse_on_gpu(x): pass + +def script(func): + return func + +inf = math.inf +nan = math.nan \ No newline at end of file diff --git a/python/jittor/compatibility/nn/__init__.py b/python/jittor/compatibility/nn/__init__.py new file mode 100644 index 00000000..ae0ff3ae --- /dev/null +++ b/python/jittor/compatibility/nn/__init__.py @@ -0,0 +1,281 @@ +import jtorch +from typing import List, Optional, Tuple, Iterable, Iterator, Mapping, Any, overload, TypeVar, Dict +from typing_extensions import Self +import jittor as jt +from jtorch import make_module, Tensor, ModuleMisc, wrapper +#from . 
import init +from jittor import Function +import operator +import warnings + +for k,v in jt.nn.__dict__.items(): + if callable(v): + globals()[k] = wrapper(v) + +for k,v in jt.nn.__dict__.items(): + if isinstance(v, type) and issubclass(v, jt.Module): + globals()[k] = make_module(v) + +from collections import OrderedDict +from collections import abc as container_abcs + +class Module(ModuleMisc, jt.Module): + + def __call__(self, *args, **kw): + return self.execute(*args, **kw) + + def execute(self, *args, **kw): + return self.forward(*args, **kw) + + def get_submodule(self, target: str): + if target == "": + return self + + atoms: List[str] = target.split(".") + mod: jt.nn.Module = self + + for item in atoms: + if not hasattr(mod, item): + raise AttributeError(mod._get_name() + " has no " + "attribute `" + item + "`") + + mod = getattr(mod, item) + + if not isinstance(mod, jt.nn.Module): + raise AttributeError("`" + item + "` is not " + "an nn.Module") + return mod + + + +def Parameter(x:Tensor, requires_grad:bool=True) -> Tensor: + x = x.clone() + x.requires_grad = requires_grad + x.retains_grad = requires_grad + return x + +def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False): + return jt.nn.embedding(input, weight) + +def dropout(x, p=0.5, training=False): + return jt.nn.dropout(x, p, training) + + +class Flatten(Module): + ''' Flattens the contiguous range of dimensions in a Var. + :param start_dim: the first dimension to be flattened. Defaults: 1. + :type start_dim: int + :param end_dim: the last dimension to be flattened. Defaults: -1. + :type end_dim: int + ''' + def __init__(self, start_dim=1, end_dim=-1): + self.start_dim = start_dim + self.end_dim = end_dim + + def forward(self, x) -> jt.Var: + return x.flatten(self.start_dim, self.end_dim) + +class _IncompatibleKeys: + def __init__(self, missing_keys, unexpected_keys): + self.missing_keys = missing_keys + self.unexpected_keys = unexpected_keys + +_BatchNorm = None + +#from . import utils +normalize = wrapper(jt.normalize) + +T = TypeVar('T', bound=Module) + +class ModuleDict(Module): + _modules: Dict[str, Module] # type: ignore[assignment] + + def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None: + super().__init__() + if modules is not None: + self.update(modules) + + def __getitem__(self, key: str) -> Module: + return self._modules[key] + + def __setitem__(self, key: str, module: Module) -> None: + self.add_module(key, module) + + def __delitem__(self, key: str) -> None: + del self._modules[key] + + def __len__(self) -> int: + return len(self._modules) + + def __iter__(self) -> Iterator[str]: + return iter(self._modules) + + def __contains__(self, key: str) -> bool: + return key in self._modules + + def clear(self) -> None: + """Remove all items from the ModuleDict.""" + self._modules.clear() + + def pop(self, key: str) -> Module: + r"""Remove key from the ModuleDict and return its module. 
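+        Raises a ``KeyError`` if ``key`` is not present in the ModuleDict.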
+ + Args: + key (str): key to pop from the ModuleDict + """ + v = self[key] + del self[key] + return v + + def keys(self) -> Iterable[str]: + r"""Return an iterable of the ModuleDict keys.""" + return self._modules.keys() + + def items(self) -> Iterable[Tuple[str, Module]]: + r"""Return an iterable of the ModuleDict key/value pairs.""" + return self._modules.items() + + def values(self) -> Iterable[Module]: + r"""Return an iterable of the ModuleDict values.""" + return self._modules.values() + + def update(self, modules: Mapping[str, Module]) -> None: + r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys. + + .. note:: + If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or + an iterable of key-value pairs, the order of new elements in it is preserved. + + Args: + modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`, + or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`) + """ + if not isinstance(modules, container_abcs.Iterable): + raise TypeError("ModuleDict.update should be called with an " + "iterable of key/value pairs, but got " + + type(modules).__name__) + + if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)): + for key, module in modules.items(): + self[key] = module + else: + # modules here can be a list with two items + for j, m in enumerate(modules): + if not isinstance(m, container_abcs.Iterable): + raise TypeError("ModuleDict update sequence element " + "#" + str(j) + " should be Iterable; is" + + type(m).__name__) + if not len(m) == 2: + raise ValueError("ModuleDict update sequence element " + "#" + str(j) + " has length " + str(len(m)) + + "; 2 is required") + # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)] + # that's too cumbersome to type correctly with overloads, so we add an ignore here + self[m[0]] = m[1] # type: ignore[assignment] + + # remove forward alltogether to fallback on Module's _forward_unimplemented + + +class ParameterList(Module): + + def __init__(self, values: Optional[Iterable[Any]] = None) -> None: + super().__init__() + self._size = 0 + if values is not None: + self += values + + def _get_abs_string_index(self, idx): + """Get the absolute index for the list of modules.""" + idx = operator.index(idx) + if not (-len(self) <= idx < len(self)): + raise IndexError(f'index {idx} is out of range') + if idx < 0: + idx += len(self) + return str(idx) + + @overload + def __getitem__(self, idx: int) -> Any: + ... + + @overload + def __getitem__(self: T, idx: slice) -> T: + ... + + def __getitem__(self, idx): + if isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + out = self.__class__() + for i in range(start, stop, step): + out.append(self[i]) + return out + else: + idx = self._get_abs_string_index(idx) + return getattr(self, str(idx)) + + def __setitem__(self, idx: int, param: Any) -> None: + # Note that all other function that add an entry to the list part of + # the ParameterList end up here. So this is the only place where we need + # to wrap things into Parameter if needed. + # Objects added via setattr() are not in the list part and thus won't + # call into this function. 
+ idx = self._get_abs_string_index(idx) + if isinstance(param, jt.Var) and not isinstance(param, Parameter): + param = Parameter(param) + return setattr(self, str(idx), param) + + def __len__(self) -> int: + return self._size + + def __iter__(self) -> Iterator[Any]: + return iter(self[i] for i in range(len(self))) + + def __iadd__(self, parameters: Iterable[Any]) -> Self: + return self.extend(parameters) + + def __dir__(self): + keys = super().__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def append(self, value: Any) -> 'ParameterList': + """Append a given value at the end of the list. + + Args: + value (Any): value to append + """ + new_idx = len(self) + self._size += 1 + self[new_idx] = value + return self + + def extend(self, values: Iterable[Any]) -> Self: + """Append values from a Python iterable to the end of the list. + + Args: + values (iterable): iterable of values to append + """ + # Tensor is an iterable but we never want to unpack it here + if not isinstance(values, container_abcs.Iterable) or isinstance(values, jt.Var): + raise TypeError("ParameterList.extend should be called with an " + "iterable, but got " + type(values).__name__) + for value in values: + self.append(value) + return self + + def extra_repr(self) -> str: + child_lines = [] + for k, p in enumerate(self): + if isinstance(p, jt.Var): + size_str = 'x'.join(str(size) for size in p.size()) + parastr = '{} containing: [{} of size {}{}]'.format( + "Parameter" if isinstance(p, Parameter) else "Tensor", + p.dtype, size_str, "cuda" if jt.flags.use_cuda else "cpu") + child_lines.append(' (' + str(k) + '): ' + parastr) + else: + child_lines.append(' (' + str(k) + '): Object of type: ' + type(p).__name__) + + tmpstr = '\n'.join(child_lines) + return tmpstr + + def __call__(self, *args, **kwargs): + raise RuntimeError('ParameterList should not be called.') \ No newline at end of file diff --git a/python/jittor/compatibility/nn/init.py b/python/jittor/compatibility/nn/init.py new file mode 100644 index 00000000..3b9f0907 --- /dev/null +++ b/python/jittor/compatibility/nn/init.py @@ -0,0 +1,16 @@ +import jittor as jt + +for k,v in jt.nn.init.__dict__.items(): + if callable(v): + globals()[k] = v + + +normal = gauss +normal_ = gauss_ +xavier_normal = xavier_gauss +xavier_normal_ = xavier_gauss_ +zeros_ = zero_ + + +jt.Var.normal_ = normal_ + diff --git a/python/jittor/compatibility/nn/utils/__init__.py b/python/jittor/compatibility/nn/utils/__init__.py new file mode 100644 index 00000000..83409f5f --- /dev/null +++ b/python/jittor/compatibility/nn/utils/__init__.py @@ -0,0 +1 @@ +from . 
import rnn \ No newline at end of file diff --git a/python/jittor/compatibility/nn/utils/rnn.py b/python/jittor/compatibility/nn/utils/rnn.py new file mode 100644 index 00000000..b32da8c3 --- /dev/null +++ b/python/jittor/compatibility/nn/utils/rnn.py @@ -0,0 +1,20 @@ +import jittor as jt + +PackedSequence = None + +def pad_sequence(sequences,batch_first=False,padding_value=0.0): + max_f = max([len(s) for s in sequences]) + # max_f = 512 + b = len(sequences) + if batch_first: + ret = sequences[0].new_full([b,max_f,]+list(sequences[0].shape[1:]),padding_value) + for i,s in enumerate(sequences): + ret[i,:len(s)] = s + else: + ret = sequences[0].new_full([max_f,b,]+list(sequences[0].shape[1:]),padding_value) + for i,s in enumerate(sequences): + ret[:len(s),i] = s + # print(ret.shape) + # ret = ret[:,:406] + return ret + \ No newline at end of file diff --git a/python/jittor/compatibility/optim.py b/python/jittor/compatibility/optim.py new file mode 100644 index 00000000..2410917f --- /dev/null +++ b/python/jittor/compatibility/optim.py @@ -0,0 +1,1854 @@ +import jittor as jt +import math +from jittor.optim import * +from functools import partial + +class Optimizer(jt.optim.Optimizer): + def pre_step(self, loss=None, retain_graph=False): + jt.flags.node_order = 1 + params_has_grad = [] + for pg in self.param_groups: + pg["grads"] = [ jt.zeros_like(p) if p.grad is None else p.grad#.float32() + for p in pg["params"] ] + for p in pg["params"]: + if p.requires_grad: + params_has_grad.append(p) + jt.sync(params_has_grad) + self.n_step += 1 + + def zero_grad(self): + for pg in self.param_groups: + pg["grads"] = [ None for p in pg["params"] ] + for p in pg["params"]: p.grad = None + + def post_step(self): + jt.flags.node_order = 0 + + def clip_grad_norm(self, max_norm:float, norm_type:int=2): + r"""Clips gradient norm of this optimizer. + The norm is computed over all gradients together. 
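+        Gradients from every parameter group are flattened and concatenated before the
+        norm is taken, and the clip coefficient is capped at 1.0, so gradients whose
+        global norm is already below ``max_norm`` are left unchanged.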
+ + Args: + max_norm (float or int): max norm of the gradients + norm_type (int): 1-norm or 2-norm + + Example:: + + a = jt.ones(2) + opt = jt.optim.SGD([a], 0.1) + + loss = a*a + opt.zero_grad() + opt.backward(loss) + + print(opt.param_groups[0]['grads'][0].norm()) # output: 2.83 + opt.clip_grad_norm(0.01, 2) + print(opt.param_groups[0]['grads'][0].norm()) # output: 0.01 + + opt.step() + + """ + self.pre_step(None) + grads = [] + for pg in self.param_groups: + for p, g in zip(pg["params"], pg["grads"]): + if p.is_stop_grad(): continue + grads.append(g.flatten()) + if len(grads) == 0: return + total_norm = jt.norm(jt.concat(grads), norm_type) + clip_coef = jt.minimum(max_norm / (total_norm + 1e-6), 1.0) + for pg in self.param_groups: + for p, g in zip(pg["params"], pg["grads"]): + if p.is_stop_grad(): continue + g.update(g*clip_coef) + + +class AdamW(Optimizer): + def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0,use_fp32=True): + print("lr:", lr) + super().__init__(params, lr) + self.eps = eps + self.betas = betas + self.weight_decay = weight_decay + + self.use_fp32 = use_fp32 + # assert weight_decay==0, "weight_decay is not supported yet" + + # initialize required arguments for each param_groups + for pg in self.param_groups: + values = pg["values"] = [] + m = pg["m"] = [] + mp = pg['masterparams'] = [] + for p in pg["params"]: + values.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad()) + m.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad()) + if self.use_fp32: + mp.append(p.detach().clone().stop_grad()) + + def add_param_group(self, group): + values = group["values"] = [] + m = group["m"] = [] + mp = group['masterparams'] = [] + for p in group["params"]: + values.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad()) + m.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad()) + if self.use_fp32: + mp.append(p.detach().clone().stop_grad()) + self.param_groups.append(group) + + def step(self, loss=None, retain_graph=False): + self.pre_step(loss, retain_graph) + if loss is None: + self.n_step += 1 + n = float(self.n_step) + for pg in self.param_groups: + # get arguments from each param_groups + lr = pg.get("lr", self.lr) + eps = pg.get("eps", self.eps) + weight_decay = pg.get("weight_decay", self.weight_decay) + b0, b1 = pg.get("betas", self.betas) + + for p, g, v, m,mp in zip(pg["params"], pg["grads"], pg["values"], pg["m"],pg['masterparams']): + if p.is_stop_grad(): continue + #if g.abs().sum().item() < 1e-8: continue + #import pdb; pdb.set_trace() + c_p = (mp * (1 - lr * weight_decay)) + mp.update(c_p) + if self.use_fp32: + g = g.float32() + bias_correction1 = 1 - b0 ** n + bias_correction2 = 1 - b1 ** n + m.update(b0 * m + (1-b0) * g) #exp_avg + v.update(b1 * v + (1-b1) * g * g) #exp_avg_sq + denom = jt.sqrt(v) / jt.sqrt(bias_correction2) + eps + step_size = lr / bias_correction1 + new_p = (mp - step_size * m / denom) + mp.update(new_p) + p.update(mp.cast(p.dtype)) + self.post_step() + +for k,v in jt.optim.__dict__.items(): + if k == "AdamW":continue + if isinstance(v, type) and issubclass(v, jt.optim.Optimizer) and \ + not v is jt.optim.Optimizer: + class OptimWrap(v, Optimizer): + pass + globals()[k] = OptimWrap + + +class Adagrad(Optimizer): + pass + + + +import types +import math +from functools import wraps +import warnings +import weakref +from collections import Counter +from bisect import bisect_right + + +class LRScheduler: + + def __init__(self, 
optimizer, last_epoch=-1, verbose=False): + + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError('{} is not an Optimizer'.format( + type(optimizer).__name__)) + self.optimizer = optimizer + + # Initialize epoch and base learning rates + if last_epoch == -1: + for group in optimizer.param_groups: + group.setdefault('initial_lr', group.get("lr",optimizer.lr)) + else: + for i, group in enumerate(optimizer.param_groups): + if 'initial_lr' not in group: + raise KeyError("param 'initial_lr' is not specified " + "in param_groups[{}] when resuming an optimizer".format(i)) + self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups] + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `lr_scheduler.step()` is called after + # `optimizer.step()` + def with_counter(method): + if getattr(method, '_with_counter', False): + # `optimizer.step()` has already been replaced, return. + return method + + # Keep a weak reference to the optimizer instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._step_count += 1 + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. + wrapper._with_counter = True + return wrapper + + self.optimizer.step = with_counter(self.optimizer.step) + self.verbose = verbose + + self._initial_step() + + def _initial_step(self): + """Initialize step counts and performs a step""" + self.optimizer._step_count = 0 + self._step_count = 0 + self.step() + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. + """ + return self._last_lr + + def get_lr(self): + # Compute learning rate using chainable form of the scheduler + raise NotImplementedError + + def print_lr(self, is_verbose, group, lr, epoch=None): + """Display the current learning rate. + """ + if is_verbose: + if epoch is None: + print('Adjusting learning rate' + ' of group {} to {:.4e}.'.format(group, lr)) + else: + epoch_str = ("%.2f" if isinstance(epoch, float) else + "%.5d") % epoch + print('Epoch {}: adjusting learning rate' + ' of group {} to {:.4e}.'.format(epoch_str, group, lr)) + + + def step(self, epoch=None): + # Raise a warning if old pattern is detected + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.optimizer.step, "_with_counter"): + warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler " + "initialization. Please, make sure to call `optimizer.step()` before " + "`lr_scheduler.step()`. 
See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) + + # Just check if there were two first lr_scheduler.step() calls before optimizer.step() + elif self.optimizer._step_count < 1: + warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. " + "In PyTorch 1.1.0 and later, you should call them in the opposite order: " + "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this " + "will result in PyTorch skipping the first value of the learning rate schedule. " + "See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) + self._step_count += 1 + + with _enable_get_lr_call(self): + if epoch is None: + self.last_epoch += 1 + values = self.get_lr() + else: + warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) + self.last_epoch = epoch + if hasattr(self, "_get_closed_form_lr"): + values = self._get_closed_form_lr() + else: + values = self.get_lr() + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group['lr'] = lr + self.print_lr(self.verbose, i, lr, epoch) + + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + +# Including _LRScheduler for backwards compatibility +# Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler). +class _LRScheduler(LRScheduler): + pass + + +class _enable_get_lr_call: + + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + + +class LambdaLR(LRScheduler): + """Sets the learning rate of each parameter group to the initial lr + times a given function. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + lr_lambda (function or list): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such + functions, one for each group in optimizer.param_groups. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer has two groups. + >>> lambda1 = lambda epoch: epoch // 30 + >>> lambda2 = lambda epoch: 0.95 ** epoch + >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False): + self.optimizer = optimizer + + if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): + self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) + else: + if len(lr_lambda) != len(optimizer.param_groups): + raise ValueError("Expected {} lr_lambdas, but got {}".format( + len(optimizer.param_groups), len(lr_lambda))) + self.lr_lambdas = list(lr_lambda) + super().__init__(optimizer, last_epoch, verbose) + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The learning rate lambda functions will only be saved if they are callable objects + and not if they are functions or lambdas. + + When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. 
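+
+        Example (illustrative sketch; assumes ``optimizer`` and ``scheduler`` already exist
+        and that the wrapped optimizer also provides ``state_dict``/``load_state_dict``):
+            >>> # xdoctest: +SKIP
+            >>> checkpoint = {'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict()}
+            >>> # ... later, after re-creating the optimizer and scheduler ...
+            >>> optimizer.load_state_dict(checkpoint['optimizer'])
+            >>> scheduler.load_state_dict(checkpoint['scheduler'])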
+ """ + + state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')} + state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas) + + for idx, fn in enumerate(self.lr_lambdas): + if not isinstance(fn, types.FunctionType): + state_dict['lr_lambdas'][idx] = fn.__dict__.copy() + + return state_dict + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + + lr_lambdas = state_dict.pop('lr_lambdas') + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict['lr_lambdas'] = lr_lambdas + + for idx, fn in enumerate(lr_lambdas): + if fn is not None: + self.lr_lambdas[idx].__dict__.update(fn) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.") + return [base_lr * lmbda(self.last_epoch) + for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] + + +class MultiplicativeLR(LRScheduler): + """Multiply the learning rate of each parameter group by the factor given + in the specified function. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + lr_lambda (function or list): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such + functions, one for each group in optimizer.param_groups. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP + >>> lmbda = lambda epoch: 0.95 + >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False): + self.optimizer = optimizer + + if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): + self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) + else: + if len(lr_lambda) != len(optimizer.param_groups): + raise ValueError("Expected {} lr_lambdas, but got {}".format( + len(optimizer.param_groups), len(lr_lambda))) + self.lr_lambdas = list(lr_lambda) + super().__init__(optimizer, last_epoch, verbose) + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The learning rate lambda functions will only be saved if they are callable objects + and not if they are functions or lambdas. + """ + state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')} + state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas) + + for idx, fn in enumerate(self.lr_lambdas): + if not isinstance(fn, types.FunctionType): + state_dict['lr_lambdas'][idx] = fn.__dict__.copy() + + return state_dict + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. 
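+        Note that only the ``__dict__`` of callable (non-function) lambda objects is
+        restored; plain functions and lambdas are left as they were constructed.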
+ """ + lr_lambdas = state_dict.pop('lr_lambdas') + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict['lr_lambdas'] = lr_lambdas + + for idx, fn in enumerate(lr_lambdas): + if fn is not None: + self.lr_lambdas[idx].__dict__.update(fn) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch > 0: + return [group['lr'] * lmbda(self.last_epoch) + for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)] + else: + return [group['lr'] for group in self.optimizer.param_groups] + + +class StepLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma every + step_size epochs. Notice that such decay can happen simultaneously with + other changes to the learning rate from outside this scheduler. When + last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + step_size (int): Period of learning rate decay. + gamma (float): Multiplicative factor of learning rate decay. + Default: 0.1. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 60 + >>> # lr = 0.0005 if 60 <= epoch < 90 + >>> # ... + >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False): + self.step_size = step_size + self.gamma = gamma + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): + return [group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] * self.gamma + for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * self.gamma ** (self.last_epoch // self.step_size) + for base_lr in self.base_lrs] + + +class MultiStepLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma once the + number of epoch reaches one of the milestones. Notice that such decay can + happen simultaneously with other changes to the learning rate from outside + this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + milestones (list): List of epoch indices. Must be increasing. + gamma (float): Multiplicative factor of learning rate decay. + Default: 0.1. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 80 + >>> # lr = 0.0005 if epoch >= 80 + >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) 
+ >>> scheduler.step() + """ + + def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False): + self.milestones = Counter(milestones) + self.gamma = gamma + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] * self.gamma ** self.milestones[self.last_epoch] + for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + milestones = sorted(self.milestones.elements()) + return [base_lr * self.gamma ** bisect_right(milestones, self.last_epoch) + for base_lr in self.base_lrs] + + +class ConstantLR(LRScheduler): + """Decays the learning rate of each parameter group by a small constant factor until the + number of epoch reaches a pre-defined milestone: total_iters. Notice that such decay can + happen simultaneously with other changes to the learning rate from outside this scheduler. + When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + factor (float): The number we multiply learning rate until the milestone. Default: 1./3. + total_iters (int): The number of steps that the scheduler decays the learning rate. + Default: 5. + last_epoch (int): The index of the last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.025 if epoch == 1 + >>> # lr = 0.025 if epoch == 2 + >>> # lr = 0.025 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False): + if factor > 1.0 or factor < 0: + raise ValueError('Constant multiplicative factor expected to be between 0 and 1.') + + self.factor = factor + self.total_iters = total_iters + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch == 0: + return [group['lr'] * self.factor for group in self.optimizer.param_groups] + + if (self.last_epoch > self.total_iters or + (self.last_epoch != self.total_iters)): + return [group['lr'] for group in self.optimizer.param_groups] + + if (self.last_epoch == self.total_iters): + return [group['lr'] * (1.0 / self.factor) for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor)) + for base_lr in self.base_lrs] + + +class LinearLR(LRScheduler): + """Decays the learning rate of each parameter group by linearly changing small + multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters. + Notice that such decay can happen simultaneously with other changes to the learning rate + from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. 
+ start_factor (float): The number we multiply learning rate in the first epoch. + The multiplication factor changes towards end_factor in the following epochs. + Default: 1./3. + end_factor (float): The number we multiply learning rate at the end of linear changing + process. Default: 1.0. + total_iters (int): The number of iterations that multiplicative factor reaches to 1. + Default: 5. + last_epoch (int): The index of the last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.03125 if epoch == 1 + >>> # lr = 0.0375 if epoch == 2 + >>> # lr = 0.04375 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1, + verbose=False): + if start_factor > 1.0 or start_factor <= 0: + raise ValueError('Starting multiplicative factor expected to be greater than 0 and less or equal to 1.') + + if end_factor > 1.0 or end_factor < 0: + raise ValueError('Ending multiplicative factor expected to be between 0 and 1.') + + self.start_factor = start_factor + self.end_factor = end_factor + self.total_iters = total_iters + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch == 0: + return [group['lr'] * self.start_factor for group in self.optimizer.param_groups] + + if self.last_epoch > self.total_iters: + return [group['lr'] for group in self.optimizer.param_groups] + + return [group['lr'] * (1. + (self.end_factor - self.start_factor) / + (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor))) + for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * (self.start_factor + + (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters) + for base_lr in self.base_lrs] + + +class ExponentialLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma every epoch. + When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + gamma (float): Multiplicative factor of learning rate decay. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. 
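+
+    Example (illustrative sketch; ``optimizer`` is assumed to be an existing optimizer):
+        >>> # xdoctest: +SKIP
+        >>> # Assuming optimizer uses lr = 0.05 for all groups
+        >>> # lr = 0.05 * (0.9 ** epoch)
+        >>> scheduler = ExponentialLR(optimizer, gamma=0.9)
+        >>> for epoch in range(100):
+        >>>     train(...)
+        >>>     validate(...)
+        >>>     scheduler.step()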
+ """ + + def __init__(self, optimizer, gamma, last_epoch=-1, verbose=False): + self.gamma = gamma + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch == 0: + return [group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] * self.gamma + for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * self.gamma ** self.last_epoch + for base_lr in self.base_lrs] + + +class SequentialLR(LRScheduler): + """Receives the list of schedulers that is expected to be called sequentially during + optimization process and milestone points that provides exact intervals to reflect + which scheduler is supposed to be called at a given epoch. + + Args: + optimizer (Optimizer): Wrapped optimizer. + schedulers (list): List of chained schedulers. + milestones (list): List of integers that reflects milestone points. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): Does nothing. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 1. for all groups + >>> # lr = 0.1 if epoch == 0 + >>> # lr = 0.1 if epoch == 1 + >>> # lr = 0.9 if epoch == 2 + >>> # lr = 0.81 if epoch == 3 + >>> # lr = 0.729 if epoch == 4 + >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) + >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) + >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False): + for scheduler_idx in range(len(schedulers)): + if schedulers[scheduler_idx].optimizer != optimizer: + raise ValueError( + "Sequential Schedulers expects all schedulers to belong to the same optimizer, but " + f"got schedulers at index {scheduler_idx} to be different than the optimizer passed in." + ) + + if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer): + raise ValueError( + "Sequential Schedulers expects all schedulers to belong to the same optimizer, but " + f"got schedulers at index {0} and {scheduler_idx} to be different." 
+ ) + if (len(milestones) != len(schedulers) - 1): + raise ValueError( + "Sequential Schedulers expects number of schedulers provided to be one more " + "than the number of milestone points, but got number of schedulers {} and the " + "number of milestones to be equal to {}".format(len(schedulers), len(milestones)) + ) + self._schedulers = schedulers + self._milestones = milestones + self.last_epoch = last_epoch + 1 + self.optimizer = optimizer + + # Reset learning rates back to initial values + for group in self.optimizer.param_groups: + group["lr"] = group["initial_lr"] + + # "Undo" the step performed by other schedulers + for scheduler in self._schedulers: + scheduler.last_epoch -= 1 + + # Perform the initial step for only the first scheduler + self._schedulers[0]._initial_step() + + self._last_lr = schedulers[0].get_last_lr() + + def step(self): + self.last_epoch += 1 + idx = bisect_right(self._milestones, self.last_epoch) + scheduler = self._schedulers[idx] + if idx > 0 and self._milestones[idx - 1] == self.last_epoch: + scheduler.step(0) + else: + scheduler.step() + + self._last_lr = scheduler.get_last_lr() + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The wrapped scheduler states will also be saved. + """ + state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')} + state_dict['_schedulers'] = [None] * len(self._schedulers) + + for idx, s in enumerate(self._schedulers): + state_dict['_schedulers'][idx] = s.state_dict() + + return state_dict + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + _schedulers = state_dict.pop('_schedulers') + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict['_schedulers'] = _schedulers + + for idx, s in enumerate(_schedulers): + self._schedulers[idx].load_state_dict(s) + + +class PolynomialLR(LRScheduler): + """Decays the learning rate of each parameter group using a polynomial function + in the given total_iters. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5. + power (int): The power of the polynomial. Default: 1.0. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP("undefined vars") + >>> # Assuming optimizer uses lr = 0.001 for all groups + >>> # lr = 0.001 if epoch == 0 + >>> # lr = 0.00075 if epoch == 1 + >>> # lr = 0.00050 if epoch == 2 + >>> # lr = 0.00025 if epoch == 3 + >>> # lr = 0.0 if epoch >= 4 + >>> scheduler = PolynomialLR(self.opt, total_iters=4, power=1.0) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) 
+ >>> scheduler.step() + """ + def __init__(self, optimizer, total_iters=5, power=1.0, last_epoch=-1, verbose=False): + self.total_iters = total_iters + self.power = power + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch == 0 or self.last_epoch > self.total_iters: + return [group["lr"] for group in self.optimizer.param_groups] + + decay_factor = ((1.0 - self.last_epoch / self.total_iters) / (1.0 - (self.last_epoch - 1) / self.total_iters)) ** self.power + return [group["lr"] * decay_factor for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [ + ( + base_lr * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters) ** self.power + ) + for base_lr in self.base_lrs + ] + + +class CosineAnnealingLR(LRScheduler): + r"""Set the learning rate of each parameter group using a cosine annealing + schedule, where :math:`\eta_{max}` is set to the initial lr and + :math:`T_{cur}` is the number of epochs since the last restart in SGDR: + + .. math:: + \begin{aligned} + \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), + & T_{cur} \neq (2k+1)T_{max}; \\ + \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) + \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), + & T_{cur} = (2k+1)T_{max}. + \end{aligned} + + When last_epoch=-1, sets initial lr as lr. Notice that because the schedule + is defined recursively, the learning rate can be simultaneously modified + outside this scheduler by other operators. If the learning rate is set + solely by this scheduler, the learning rate at each step becomes: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) + + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only + implements the cosine annealing part of SGDR, and not the restarts. + + Args: + optimizer (Optimizer): Wrapped optimizer. + T_max (int): Maximum number of iterations. + eta_min (float): Minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. 
_SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + """ + + def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False): + self.T_max = T_max + self.eta_min = eta_min + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch == 0: + return [group['lr'] for group in self.optimizer.param_groups] + elif self._step_count == 1 and self.last_epoch > 0: + return [self.eta_min + (base_lr - self.eta_min) * + (1 + math.cos((self.last_epoch) * math.pi / self.T_max)) / 2 + for base_lr, group in + zip(self.base_lrs, self.optimizer.param_groups)] + elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0: + return [group['lr'] + (base_lr - self.eta_min) * + (1 - math.cos(math.pi / self.T_max)) / 2 + for base_lr, group in + zip(self.base_lrs, self.optimizer.param_groups)] + return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) / + (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * + (group['lr'] - self.eta_min) + self.eta_min + for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [self.eta_min + (base_lr - self.eta_min) * + (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 + for base_lr in self.base_lrs] + + +class ChainedScheduler(LRScheduler): + """Chains list of learning rate schedulers. It takes a list of chainable learning + rate schedulers and performs consecutive step() functions belonging to them by just + one call. + + Args: + schedulers (list): List of chained schedulers. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 1. for all groups + >>> # lr = 0.09 if epoch == 0 + >>> # lr = 0.081 if epoch == 1 + >>> # lr = 0.729 if epoch == 2 + >>> # lr = 0.6561 if epoch == 3 + >>> # lr = 0.59049 if epoch >= 4 + >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) + >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) + >>> scheduler = ChainedScheduler([scheduler1, scheduler2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, schedulers): + for scheduler_idx in range(1, len(schedulers)): + if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer): + raise ValueError( + "ChainedScheduler expects all schedulers to belong to the same optimizer, but " + "got schedulers at index {} and {} to be different".format(0, scheduler_idx) + ) + self._schedulers = list(schedulers) + self.optimizer = schedulers[0].optimizer + self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups] + + def step(self): + for scheduler in self._schedulers: + scheduler.step() + self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups] + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The wrapped scheduler states will also be saved. + """ + state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')} + state_dict['_schedulers'] = [None] * len(self._schedulers) + + for idx, s in enumerate(self._schedulers): + state_dict['_schedulers'][idx] = s.state_dict() + + return state_dict + + def load_state_dict(self, state_dict): + """Loads the schedulers state. 
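+        The wrapped schedulers' states are restored one by one, in the same order in
+        which they were saved by :meth:`state_dict`.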
+ + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + _schedulers = state_dict.pop('_schedulers') + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict['_schedulers'] = _schedulers + + for idx, s in enumerate(_schedulers): + self._schedulers[idx].load_state_dict(s) + + +class ReduceLROnPlateau: + """Reduce learning rate when a metric has stopped improving. + Models often benefit from reducing the learning rate by a factor + of 2-10 once learning stagnates. This scheduler reads a metrics + quantity and if no improvement is seen for a 'patience' number + of epochs, the learning rate is reduced. + + Args: + optimizer (Optimizer): Wrapped optimizer. + mode (str): One of `min`, `max`. In `min` mode, lr will + be reduced when the quantity monitored has stopped + decreasing; in `max` mode it will be reduced when the + quantity monitored has stopped increasing. Default: 'min'. + factor (float): Factor by which the learning rate will be + reduced. new_lr = lr * factor. Default: 0.1. + patience (int): Number of epochs with no improvement after + which learning rate will be reduced. For example, if + `patience = 2`, then we will ignore the first 2 epochs + with no improvement, and will only decrease the LR after the + 3rd epoch if the loss still hasn't improved then. + Default: 10. + threshold (float): Threshold for measuring the new optimum, + to only focus on significant changes. Default: 1e-4. + threshold_mode (str): One of `rel`, `abs`. In `rel` mode, + dynamic_threshold = best * ( 1 + threshold ) in 'max' + mode or best * ( 1 - threshold ) in `min` mode. + In `abs` mode, dynamic_threshold = best + threshold in + `max` mode or best - threshold in `min` mode. Default: 'rel'. + cooldown (int): Number of epochs to wait before resuming + normal operation after lr has been reduced. Default: 0. + min_lr (float or list): A scalar or a list of scalars. A + lower bound on the learning rate of all param groups + or each group respectively. Default: 0. + eps (float): Minimal decay applied to lr. If the difference + between new and old lr is smaller than eps, the update is + ignored. Default: 1e-8. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = ReduceLROnPlateau(optimizer, 'min') + >>> for epoch in range(10): + >>> train(...) + >>> val_loss = validate(...) 
+ >>> # Note that step should be called after validate() + >>> scheduler.step(val_loss) + """ + + def __init__(self, optimizer, mode='min', factor=0.1, patience=10, + threshold=1e-4, threshold_mode='rel', cooldown=0, + min_lr=0, eps=1e-8, verbose=False): + + if factor >= 1.0: + raise ValueError('Factor should be < 1.0.') + self.factor = factor + + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError('{} is not an Optimizer'.format( + type(optimizer).__name__)) + self.optimizer = optimizer + + if isinstance(min_lr, (list, tuple)): + if len(min_lr) != len(optimizer.param_groups): + raise ValueError("expected {} min_lrs, got {}".format( + len(optimizer.param_groups), len(min_lr))) + self.min_lrs = list(min_lr) + else: + self.min_lrs = [min_lr] * len(optimizer.param_groups) + + self.patience = patience + self.verbose = verbose + self.cooldown = cooldown + self.cooldown_counter = 0 + self.mode = mode + self.threshold = threshold + self.threshold_mode = threshold_mode + self.best = None + self.num_bad_epochs = None + self.mode_worse = None # the worse value for the chosen mode + self.eps = eps + self.last_epoch = 0 + self._init_is_better(mode=mode, threshold=threshold, + threshold_mode=threshold_mode) + self._reset() + + def _reset(self): + """Resets num_bad_epochs counter and cooldown counter.""" + self.best = self.mode_worse + self.cooldown_counter = 0 + self.num_bad_epochs = 0 + + def step(self, metrics, epoch=None): + # convert `metrics` to float, in case it's a zero-dim Tensor + current = float(metrics) + if epoch is None: + epoch = self.last_epoch + 1 + else: + warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) + self.last_epoch = epoch + + if self.is_better(current, self.best): + self.best = current + self.num_bad_epochs = 0 + else: + self.num_bad_epochs += 1 + + if self.in_cooldown: + self.cooldown_counter -= 1 + self.num_bad_epochs = 0 # ignore any bad epochs in cooldown + + if self.num_bad_epochs > self.patience: + self._reduce_lr(epoch) + self.cooldown_counter = self.cooldown + self.num_bad_epochs = 0 + + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + + def _reduce_lr(self, epoch): + for i, param_group in enumerate(self.optimizer.param_groups): + old_lr = float(param_group['lr']) + new_lr = max(old_lr * self.factor, self.min_lrs[i]) + if old_lr - new_lr > self.eps: + param_group['lr'] = new_lr + if self.verbose: + epoch_str = ("%.2f" if isinstance(epoch, float) else + "%.5d") % epoch + print('Epoch {}: reducing learning rate' + ' of group {} to {:.4e}.'.format(epoch_str, i, new_lr)) + + @property + def in_cooldown(self): + return self.cooldown_counter > 0 + + def is_better(self, a, best): + if self.mode == 'min' and self.threshold_mode == 'rel': + rel_epsilon = 1. - self.threshold + return a < best * rel_epsilon + + elif self.mode == 'min' and self.threshold_mode == 'abs': + return a < best - self.threshold + + elif self.mode == 'max' and self.threshold_mode == 'rel': + rel_epsilon = self.threshold + 1. 
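+            # 'max' mode with a relative threshold: the metric counts as an improvement
+            # only if it exceeds the previous best by the relative margin `threshold`.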
+ return a > best * rel_epsilon + + else: # mode == 'max' and epsilon_mode == 'abs': + return a > best + self.threshold + + def _init_is_better(self, mode, threshold, threshold_mode): + if mode not in {'min', 'max'}: + raise ValueError('mode ' + mode + ' is unknown!') + if threshold_mode not in {'rel', 'abs'}: + raise ValueError('threshold mode ' + threshold_mode + ' is unknown!') + + if mode == 'min': + self.mode_worse = inf + else: # mode == 'max': + self.mode_worse = -inf + + self.mode = mode + self.threshold = threshold + self.threshold_mode = threshold_mode + + def state_dict(self): + return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} + + def load_state_dict(self, state_dict): + self.__dict__.update(state_dict) + self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode) + + +class CyclicLR(LRScheduler): + r"""Sets the learning rate of each parameter group according to + cyclical learning rate policy (CLR). The policy cycles the learning + rate between two boundaries with a constant frequency, as detailed in + the paper `Cyclical Learning Rates for Training Neural Networks`_. + The distance between the two boundaries can be scaled on a per-iteration + or per-cycle basis. + + Cyclical learning rate policy changes the learning rate after every batch. + `step` should be called after a batch has been used for training. + + This class has three built-in policies, as put forth in the paper: + + * "triangular": A basic triangular cycle without amplitude scaling. + * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle. + * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}` + at each cycle iteration. + + This implementation was adapted from the github repo: `bckenstler/CLR`_ + + Args: + optimizer (Optimizer): Wrapped optimizer. + base_lr (float or list): Initial learning rate which is the + lower boundary in the cycle for each parameter group. + max_lr (float or list): Upper learning rate boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_lr - base_lr). + The lr at any cycle is the sum of base_lr + and some scaling of the amplitude; therefore + max_lr may not actually be reached depending on + scaling function. + step_size_up (int): Number of training iterations in the + increasing half of a cycle. Default: 2000 + step_size_down (int): Number of training iterations in the + decreasing half of a cycle. If step_size_down is None, + it is set to step_size_up. Default: None + mode (str): One of {triangular, triangular2, exp_range}. + Values correspond to policies detailed above. + If scale_fn is not None, this argument is ignored. + Default: 'triangular' + gamma (float): Constant in 'exp_range' scaling function: + gamma**(cycle iterations) + Default: 1.0 + scale_fn (function): Custom scaling policy defined by a single + argument lambda function, where + 0 <= scale_fn(x) <= 1 for all x >= 0. + If specified, then 'mode' is ignored. + Default: None + scale_mode (str): {'cycle', 'iterations'}. + Defines whether scale_fn is evaluated on + cycle number or cycle iterations (training + iterations since start of cycle). + Default: 'cycle' + cycle_momentum (bool): If ``True``, momentum is cycled inversely + to learning rate between 'base_momentum' and 'max_momentum'. + Default: True + base_momentum (float or list): Lower momentum boundaries in the cycle + for each parameter group. 
Note that momentum is cycled inversely + to learning rate; at the peak of a cycle, momentum is + 'base_momentum' and learning rate is 'max_lr'. + Default: 0.8 + max_momentum (float or list): Upper momentum boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_momentum - base_momentum). + The momentum at any cycle is the difference of max_momentum + and some scaling of the amplitude; therefore + base_momentum may not actually be reached depending on + scaling function. Note that momentum is cycled inversely + to learning rate; at the start of a cycle, momentum is 'max_momentum' + and learning rate is 'base_lr' + Default: 0.9 + last_epoch (int): The index of the last batch. This parameter is used when + resuming a training job. Since `step()` should be invoked after each + batch instead of after each epoch, this number represents the total + number of *batches* computed, not the total number of epochs computed. + When last_epoch=-1, the schedule is started from the beginning. + Default: -1 + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1) + >>> data_loader = torch.utils.data.DataLoader(...) + >>> for epoch in range(10): + >>> for batch in data_loader: + >>> train_batch(...) + >>> scheduler.step() + + + .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 + .. _bckenstler/CLR: https://github.com/bckenstler/CLR + """ + + def __init__(self, + optimizer, + base_lr, + max_lr, + step_size_up=2000, + step_size_down=None, + mode='triangular', + gamma=1., + scale_fn=None, + scale_mode='cycle', + cycle_momentum=True, + base_momentum=0.8, + max_momentum=0.9, + last_epoch=-1, + verbose=False): + + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError('{} is not an Optimizer'.format( + type(optimizer).__name__)) + self.optimizer = optimizer + + base_lrs = self._format_param('base_lr', optimizer, base_lr) + if last_epoch == -1: + for lr, group in zip(base_lrs, optimizer.param_groups): + group['lr'] = lr + + self.max_lrs = self._format_param('max_lr', optimizer, max_lr) + + step_size_up = float(step_size_up) + step_size_down = float(step_size_down) if step_size_down is not None else step_size_up + self.total_size = step_size_up + step_size_down + self.step_ratio = step_size_up / self.total_size + + if mode not in ['triangular', 'triangular2', 'exp_range'] \ + and scale_fn is None: + raise ValueError('mode is invalid and scale_fn is None') + + self.mode = mode + self.gamma = gamma + + self._scale_fn_ref = None + self._scale_fn_custom = scale_fn + self.scale_mode = scale_mode + self._init_scale_fn() + + self.cycle_momentum = cycle_momentum + if cycle_momentum: + if 'momentum' not in optimizer.defaults: + raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled') + + base_momentums = self._format_param('base_momentum', optimizer, base_momentum) + if last_epoch == -1: + for momentum, group in zip(base_momentums, optimizer.param_groups): + group['momentum'] = momentum + self.base_momentums = [group['momentum'] for group in optimizer.param_groups] + self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum) + + super().__init__(optimizer, last_epoch, verbose) + self.base_lrs = base_lrs + + def 
_init_scale_fn(self): + if self._scale_fn_custom is not None: + return + if self.mode == 'triangular': + self._scale_fn_ref = weakref.WeakMethod(self._triangular_scale_fn) + self.scale_mode = 'cycle' + elif self.mode == 'triangular2': + self._scale_fn_ref = weakref.WeakMethod(self._triangular2_scale_fn) + self.scale_mode = 'cycle' + elif self.mode == 'exp_range': + self._scale_fn_ref = weakref.WeakMethod(self._exp_range_scale_fn) + self.scale_mode = 'iterations' + + def _format_param(self, name, optimizer, param): + """Return correctly formatted lr/momentum for each param group.""" + if isinstance(param, (list, tuple)): + if len(param) != len(optimizer.param_groups): + raise ValueError("expected {} values for {}, got {}".format( + len(optimizer.param_groups), name, len(param))) + return param + else: + return [param] * len(optimizer.param_groups) + + def scale_fn(self, x): + if self._scale_fn_custom is not None: + return self._scale_fn_custom(x) + + else: + return self._scale_fn_ref()(x) + + def _triangular_scale_fn(self, x): + return 1. + + def _triangular2_scale_fn(self, x): + return 1 / (2. ** (x - 1)) + + def _exp_range_scale_fn(self, x): + return self.gamma**(x) + + def get_lr(self): + """Calculates the learning rate at batch index. This function treats + `self.last_epoch` as the last batch index. + + If `self.cycle_momentum` is ``True``, this function has a side effect of + updating the optimizer's momentum. + """ + + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + cycle = math.floor(1 + self.last_epoch / self.total_size) + x = 1. + self.last_epoch / self.total_size - cycle + if x <= self.step_ratio: + scale_factor = x / self.step_ratio + else: + scale_factor = (x - 1) / (self.step_ratio - 1) + + lrs = [] + for base_lr, max_lr in zip(self.base_lrs, self.max_lrs): + base_height = (max_lr - base_lr) * scale_factor + if self.scale_mode == 'cycle': + lr = base_lr + base_height * self.scale_fn(cycle) + else: + lr = base_lr + base_height * self.scale_fn(self.last_epoch) + lrs.append(lr) + + if self.cycle_momentum: + momentums = [] + for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums): + base_height = (max_momentum - base_momentum) * scale_factor + if self.scale_mode == 'cycle': + momentum = max_momentum - base_height * self.scale_fn(cycle) + else: + momentum = max_momentum - base_height * self.scale_fn(self.last_epoch) + momentums.append(momentum) + for param_group, momentum in zip(self.optimizer.param_groups, momentums): + param_group['momentum'] = momentum + + return lrs + + def state_dict(self): + state = super().state_dict() + # We are dropping the `_scale_fn_ref` attribute because it is a `weakref.WeakMethod` and can't be pickled + state.pop("_scale_fn_ref") + return state + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + self._init_scale_fn() + + + +class CosineAnnealingWarmRestarts(LRScheduler): + r"""Set the learning rate of each parameter group using a cosine annealing + schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` + is the number of epochs since the last restart and :math:`T_{i}` is the number + of epochs between two warm restarts in SGDR: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right) + + When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. 
+ When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`. + + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. + + Args: + optimizer (Optimizer): Wrapped optimizer. + T_0 (int): Number of iterations for the first restart. + T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1. + eta_min (float, optional): Minimum learning rate. Default: 0. + last_epoch (int, optional): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + """ + + def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False): + if T_0 <= 0 or not isinstance(T_0, int): + raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) + if T_mult < 1 or not isinstance(T_mult, int): + raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult)) + self.T_0 = T_0 + self.T_i = T_0 + self.T_mult = T_mult + self.eta_min = eta_min + self.T_cur = last_epoch + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2 + for base_lr in self.base_lrs] + + def step(self, epoch=None): + """Step could be called after every batch update + + Example: + >>> # xdoctest: +SKIP("Undefined vars") + >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) + >>> iters = len(dataloader) + >>> for epoch in range(20): + >>> for i, sample in enumerate(dataloader): + >>> inputs, labels = sample['inputs'], sample['labels'] + >>> optimizer.zero_grad() + >>> outputs = net(inputs) + >>> loss = criterion(outputs, labels) + >>> loss.backward() + >>> optimizer.step() + >>> scheduler.step(epoch + i / iters) + + This function can be called in an interleaved way. 
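+        When ``epoch`` is passed explicitly it is treated as an absolute position in the
+        schedule, so explicit and implicit calls can be mixed, as in the example below.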
+ + Example: + >>> # xdoctest: +SKIP("Undefined vars") + >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) + >>> for epoch in range(20): + >>> scheduler.step() + >>> scheduler.step(26) + >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) + """ + + if epoch is None and self.last_epoch < 0: + epoch = 0 + + if epoch is None: + epoch = self.last_epoch + 1 + self.T_cur = self.T_cur + 1 + if self.T_cur >= self.T_i: + self.T_cur = self.T_cur - self.T_i + self.T_i = self.T_i * self.T_mult + else: + if epoch < 0: + raise ValueError("Expected non-negative epoch, but got {}".format(epoch)) + if epoch >= self.T_0: + if self.T_mult == 1: + self.T_cur = epoch % self.T_0 + else: + n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult)) + self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1) + self.T_i = self.T_0 * self.T_mult ** (n) + else: + self.T_i = self.T_0 + self.T_cur = epoch + self.last_epoch = math.floor(epoch) + + class _enable_get_lr_call: + + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + return self + + with _enable_get_lr_call(self): + for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())): + param_group, lr = data + param_group['lr'] = lr + self.print_lr(self.verbose, i, lr, epoch) + + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + + +class OneCycleLR(LRScheduler): + r"""Sets the learning rate of each parameter group according to the + 1cycle learning rate policy. The 1cycle policy anneals the learning + rate from an initial learning rate to some maximum learning rate and then + from that maximum learning rate to some minimum learning rate much lower + than the initial learning rate. + This policy was initially described in the paper `Super-Convergence: + Very Fast Training of Neural Networks Using Large Learning Rates`_. + + The 1cycle learning rate policy changes the learning rate after every batch. + `step` should be called after a batch has been used for training. + + This scheduler is not chainable. + + Note also that the total number of steps in the cycle can be determined in one + of two ways (listed in order of precedence): + + #. A value for total_steps is explicitly provided. + #. A number of epochs (epochs) and a number of steps per epoch + (steps_per_epoch) are provided. + In this case, the number of total steps is inferred by + total_steps = epochs * steps_per_epoch + + You must either provide a value for total_steps or provide a value for both + epochs and steps_per_epoch. + + The default behaviour of this scheduler follows the fastai implementation of 1cycle, which + claims that "unpublished work has shown even better results by using only two phases". To + mimic the behaviour of the original paper instead, set ``three_phase=True``. + + Args: + optimizer (Optimizer): Wrapped optimizer. + max_lr (float or list): Upper learning rate boundaries in the cycle + for each parameter group. + total_steps (int): The total number of steps in the cycle. Note that + if a value is not provided here, then it must be inferred by providing + a value for epochs and steps_per_epoch. + Default: None + epochs (int): The number of epochs to train for. This is used along + with steps_per_epoch in order to infer the total number of steps in the cycle + if a value for total_steps is not provided. 
+ Default: None + steps_per_epoch (int): The number of steps per epoch to train for. This is + used along with epochs in order to infer the total number of steps in the + cycle if a value for total_steps is not provided. + Default: None + pct_start (float): The percentage of the cycle (in number of steps) spent + increasing the learning rate. + Default: 0.3 + anneal_strategy (str): {'cos', 'linear'} + Specifies the annealing strategy: "cos" for cosine annealing, "linear" for + linear annealing. + Default: 'cos' + cycle_momentum (bool): If ``True``, momentum is cycled inversely + to learning rate between 'base_momentum' and 'max_momentum'. + Default: True + base_momentum (float or list): Lower momentum boundaries in the cycle + for each parameter group. Note that momentum is cycled inversely + to learning rate; at the peak of a cycle, momentum is + 'base_momentum' and learning rate is 'max_lr'. + Default: 0.85 + max_momentum (float or list): Upper momentum boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_momentum - base_momentum). + Note that momentum is cycled inversely + to learning rate; at the start of a cycle, momentum is 'max_momentum' + and learning rate is 'base_lr' + Default: 0.95 + div_factor (float): Determines the initial learning rate via + initial_lr = max_lr/div_factor + Default: 25 + final_div_factor (float): Determines the minimum learning rate via + min_lr = initial_lr/final_div_factor + Default: 1e4 + three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the + learning rate according to 'final_div_factor' instead of modifying the second + phase (the first two phases will be symmetrical about the step indicated by + 'pct_start'). + last_epoch (int): The index of the last batch. This parameter is used when + resuming a training job. Since `step()` should be invoked after each + batch instead of after each epoch, this number represents the total + number of *batches* computed, not the total number of epochs computed. + When last_epoch=-1, the schedule is started from the beginning. + Default: -1 + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP + >>> data_loader = torch.utils.data.DataLoader(...) + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10) + >>> for epoch in range(10): + >>> for batch in data_loader: + >>> train_batch(...) + >>> optimizer.step() + >>> scheduler.step() + + + .. 
_Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates: + https://arxiv.org/abs/1708.07120 + """ + def __init__(self, + optimizer, + max_lr, + total_steps=None, + epochs=None, + steps_per_epoch=None, + pct_start=0.3, + anneal_strategy='cos', + cycle_momentum=True, + base_momentum=0.85, + max_momentum=0.95, + div_factor=25., + final_div_factor=1e4, + three_phase=False, + last_epoch=-1, + verbose=False): + + # Validate optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError('{} is not an Optimizer'.format( + type(optimizer).__name__)) + self.optimizer = optimizer + + # Validate total_steps + if total_steps is None and epochs is None and steps_per_epoch is None: + raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)") + elif total_steps is not None: + if total_steps <= 0 or not isinstance(total_steps, int): + raise ValueError("Expected positive integer total_steps, but got {}".format(total_steps)) + self.total_steps = total_steps + else: + if epochs <= 0 or not isinstance(epochs, int): + raise ValueError("Expected positive integer epochs, but got {}".format(epochs)) + if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int): + raise ValueError("Expected positive integer steps_per_epoch, but got {}".format(steps_per_epoch)) + self.total_steps = epochs * steps_per_epoch + + if three_phase: + self._schedule_phases = [ + { + 'end_step': float(pct_start * self.total_steps) - 1, + 'start_lr': 'initial_lr', + 'end_lr': 'max_lr', + 'start_momentum': 'max_momentum', + 'end_momentum': 'base_momentum', + }, + { + 'end_step': float(2 * pct_start * self.total_steps) - 2, + 'start_lr': 'max_lr', + 'end_lr': 'initial_lr', + 'start_momentum': 'base_momentum', + 'end_momentum': 'max_momentum', + }, + { + 'end_step': self.total_steps - 1, + 'start_lr': 'initial_lr', + 'end_lr': 'min_lr', + 'start_momentum': 'max_momentum', + 'end_momentum': 'max_momentum', + }, + ] + else: + self._schedule_phases = [ + { + 'end_step': float(pct_start * self.total_steps) - 1, + 'start_lr': 'initial_lr', + 'end_lr': 'max_lr', + 'start_momentum': 'max_momentum', + 'end_momentum': 'base_momentum', + }, + { + 'end_step': self.total_steps - 1, + 'start_lr': 'max_lr', + 'end_lr': 'min_lr', + 'start_momentum': 'base_momentum', + 'end_momentum': 'max_momentum', + }, + ] + + # Validate pct_start + if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): + raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start)) + + # Validate anneal_strategy + if anneal_strategy not in ['cos', 'linear']: + raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy)) + elif anneal_strategy == 'cos': + self.anneal_func = self._annealing_cos + elif anneal_strategy == 'linear': + self.anneal_func = self._annealing_linear + + # Initialize learning rate variables + max_lrs = self._format_param('max_lr', self.optimizer, max_lr) + if last_epoch == -1: + for idx, group in enumerate(self.optimizer.param_groups): + group['initial_lr'] = max_lrs[idx] / div_factor + group['max_lr'] = max_lrs[idx] + group['min_lr'] = group['initial_lr'] / final_div_factor + + # Initialize momentum variables + self.cycle_momentum = cycle_momentum + if self.cycle_momentum: + if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults: + raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled') + self.use_beta1 = 'betas' in self.optimizer.defaults + 
max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
+            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
+            if last_epoch == -1:
+                for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups):
+                    if self.use_beta1:
+                        group['betas'] = (m_momentum, *group['betas'][1:])
+                    else:
+                        group['momentum'] = m_momentum
+                    group['max_momentum'] = m_momentum
+                    group['base_momentum'] = b_momentum
+
+        super().__init__(optimizer, last_epoch, verbose)
+
+    def _format_param(self, name, optimizer, param):
+        """Return correctly formatted lr/momentum for each param group."""
+        if isinstance(param, (list, tuple)):
+            if len(param) != len(optimizer.param_groups):
+                raise ValueError("expected {} values for {}, got {}".format(
+                    len(optimizer.param_groups), name, len(param)))
+            return param
+        else:
+            return [param] * len(optimizer.param_groups)
+
+    def _annealing_cos(self, start, end, pct):
+        "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+        cos_out = math.cos(math.pi * pct) + 1
+        return end + (start - end) / 2.0 * cos_out
+
+    def _annealing_linear(self, start, end, pct):
+        "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+        return (end - start) * pct + start
+
+    def get_lr(self):
+        if not self._get_lr_called_within_step:
+            warnings.warn("To get the last learning rate computed by the scheduler, "
+                          "please use `get_last_lr()`.", UserWarning)
+
+        lrs = []
+        step_num = self.last_epoch
+
+        if step_num > self.total_steps:
+            raise ValueError("Tried to step {} times. The specified number of total steps is {}"
+                             .format(step_num, self.total_steps))
+
+        for group in self.optimizer.param_groups:
+            start_step = 0
+            for i, phase in enumerate(self._schedule_phases):
+                end_step = phase['end_step']
+                if step_num <= end_step or i == len(self._schedule_phases) - 1:
+                    pct = (step_num - start_step) / (end_step - start_step)
+                    computed_lr = self.anneal_func(group[phase['start_lr']], group[phase['end_lr']], pct)
+                    if self.cycle_momentum:
+                        computed_momentum = self.anneal_func(group[phase['start_momentum']], group[phase['end_momentum']], pct)
+                    break
+                start_step = phase['end_step']
+
+            lrs.append(computed_lr)
+            if self.cycle_momentum:
+                if self.use_beta1:
+                    group['betas'] = (computed_momentum, *group['betas'][1:])
+                else:
+                    group['momentum'] = computed_momentum
+
+        return lrs
\ No newline at end of file
diff --git a/python/jittor/compatibility/src/jtorch_core.cc b/python/jittor/compatibility/src/jtorch_core.cc
new file mode 100644
index 00000000..1102b107
--- /dev/null
+++ b/python/jittor/compatibility/src/jtorch_core.cc
@@ -0,0 +1,102 @@
+
+#include "pyjt/py_obj_holder.h"
+#include "utils/str_utils.h"
+#include "jtorch_core.h"
+#include "graph.h"
+#include "grad.h"
+#include "ops/op_register.h"
+
+namespace jittor {
+
+void pyjt_def_all(PyObject* m);
+
+EXTERN_LIB void setter_use_cuda(int value);
+
+Device::Device(const string& name, int ordinal) : name(name) {
+    if (startswith(name, "cpu"))
+        setter_use_cuda(0);
+    else
+        setter_use_cuda(1);
+}
+
+unordered_map<int64, VarPtr> grad_backup;
+EXTERN_LIB void (*_var_free_hook)(Var*);
+EXTERN_LIB unordered_map<int64, VarPtr>* _grad_backup_ptr;
+
+void jtorch_var_free_hook(Var* v) {
+    auto iter = grad_backup.find(v->id);
+    if (iter != grad_backup.end()) {
+        grad_backup.erase(iter);
+    }
+}
+
+void jtorch_init() {
+    _var_free_hook = &jtorch_var_free_hook;
+    _grad_backup_ptr = &grad_backup;
+}
+
+inline static VarPtr& get_grad(Var* v) {
+    return grad_backup[v->id];
+}
+static auto make_binary = 
get_op_info("binary") + .get_constructor(); + +inline static void add_grad(VarPtr& a, VarPtr&& b) { + if (!a) a = move(b); + else { + a = make_binary(a, b, ns_add); + } +} + + +void grad_set(VarHolder* x, Maybe v) { + if (!v) { + grad_del(x); + return; + } + grad_backup[x->var->id] = v.ptr->var; +} + +Maybe grad_get(VarHolder* x) { + auto iter = grad_backup.find(x->var->id); + if (iter != grad_backup.end()) { + if (!iter->second.ptr) return nullptr; + return new VarHolder(iter->second.ptr); + } + return nullptr; +} + +void grad_del(VarHolder* x) { + auto iter = grad_backup.find(x->var->id); + if (iter != grad_backup.end()) + grad_backup.erase(iter); +} + +void backward(VarHolder* x) { + vector gnodes({x->var}); + bfs_backward(gnodes, [&](Node* node) { + if (node->is_stop_grad()) + return false; + return true; + }); + vector targets; + for (auto* node : gnodes) { + if (node->is_var() && node->flags.get(NodeFlags::_th_require_grad)) + targets.push_back(node->var()); + } + auto grads = grad(x->var, targets); + for (int i=0; im_doc = "Inner c++ core of jtorch"; + jittor::pyjt_def_all(m); +} +PYJT_MODULE_INIT(jtorch_core); diff --git a/python/jittor/compatibility/src/jtorch_core.h b/python/jittor/compatibility/src/jtorch_core.h new file mode 100644 index 00000000..36de6522 --- /dev/null +++ b/python/jittor/compatibility/src/jtorch_core.h @@ -0,0 +1,40 @@ +#pragma once +#include "common.h" +#include "var_holder.h" +#include "misc/fast_shared_ptr.h" + +namespace jittor { + +// @pyjt(device) +// @attrs(heaptype) +struct Device { + string name; + + // @pyjt(__init__) + Device(const string& name, int ordinal=0); + // @pyjt(__get__type, __str__) + inline string get_type() {return name;} + // @pyjt(__get__index) + inline int index() {return 0;} +}; + +// @pyjt(backward) +void backward(VarHolder* x); + +// @pyjt(grad_set) +void grad_set(VarHolder* x, Maybe v); +// @pyjt(grad_get) +Maybe grad_get(VarHolder* x); +// @pyjt(grad_del) +void grad_del(VarHolder* x); + +// @pyjt(retain_grad_set) +inline void retain_grad_set(VarHolder* x, bool v) { + x->var->flags.set(NodeFlags::_th_require_grad, v); +} +// @pyjt(retain_grad_get) +inline bool retain_grad_get(VarHolder* x) { + return x->var->flags.get(NodeFlags::_th_require_grad); +} + +} \ No newline at end of file diff --git a/python/jittor/compatibility/test/test_conflict_func.py b/python/jittor/compatibility/test/test_conflict_func.py new file mode 100644 index 00000000..97bd7d8f --- /dev/null +++ b/python/jittor/compatibility/test/test_conflict_func.py @@ -0,0 +1,25 @@ +import unittest +import numpy as np +import torch +import jittor as jt + +class TestConflictFunc(unittest.TestCase): + def test_max(self): + a = torch.Tensor([1,4,2]) + assert a.max() == 4 + v, k = a.max(dim=0) + assert v==4 and k==1 + + def test_argsort(self): + a = torch.Tensor([1,4,2]) + k = a.argsort() + assert jt.all_equal(k, [0,2,1]) + + with jt.flag_scope(th_mode=0): + k, v = a.argsort() + assert jt.all_equal(k, [0,2,1]) + + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/compatibility/test/test_function.py b/python/jittor/compatibility/test/test_function.py new file mode 100644 index 00000000..9959dbae --- /dev/null +++ b/python/jittor/compatibility/test/test_function.py @@ -0,0 +1,58 @@ +import unittest +import numpy as np +import torch + +class TestFunction(unittest.TestCase): + def test_example1(self): + import jtorch + from jtorch import Function + + class MyFunc(Function): + @staticmethod + def forward(self, x, y): + self.x = x + self.y = y + return 
x*y, x/y + + @staticmethod + def backward(self, grad0, grad1): + return grad0 * self.y, grad1 * self.x + + a = jtorch.array(3.0) + a.requires_grad = True + b = jtorch.array(4.0) + b.requires_grad = True + func = MyFunc.apply + c,d = func(a, b) + (c+d*3).backward() + assert a.grad.data == 4 + assert b.grad.data == 9 + + def test_example2(self): + import jtorch as jt + from jtorch import Function + + class MyFunc(Function): + @staticmethod + def forward(self, x, y): + self.x = x + self.y = y + return x*y, x/y + + @staticmethod + def backward(self, grad0, grad1): + assert grad1 is None + return grad0 * self.y, None + a = jt.array(3.0) + a.requires_grad = True + b = jt.array(4.0) + b.requires_grad = True + func = MyFunc.apply + c,d = func(a, b) + d.stop_grad() + da, db = jt.grad(c+d*3, [a, b]) + assert da.data == 4 + assert db.data == 0 + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/compatibility/test/test_misc.py b/python/jittor/compatibility/test/test_misc.py new file mode 100644 index 00000000..00bf1b70 --- /dev/null +++ b/python/jittor/compatibility/test/test_misc.py @@ -0,0 +1,24 @@ +import unittest +import numpy as np +import torch + +class TestMisc(unittest.TestCase): + def test_update_grad(self): + class Net(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.nn.Parameter(torch.Tensor([1.0, 2.0])) + net = Net() + assert(net.a.requires_grad) + net.load_state_dict({"a": torch.Tensor([3.0, 4.0])}) + assert(net.a.requires_grad) + + def test_reshape(self): + a = torch.ones(3,3) + a.requires_grad = True + b = torch.reshape(a, [9]) + assert b.requires_grad == True + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/compatibility/test/test_tutorial.py b/python/jittor/compatibility/test/test_tutorial.py new file mode 100644 index 00000000..92c087c7 --- /dev/null +++ b/python/jittor/compatibility/test/test_tutorial.py @@ -0,0 +1,56 @@ +import unittest +import numpy as np +import os +import subprocess as sp +import sys + +def check_two(cmd, parser=None, checker=None): + jtorch_out = sp.getoutput(cmd) + print("=========JTORCH OUT==========") + print(jtorch_out) + torch_out = sp.getoutput("PYTHONPATH= "+cmd) + print("=========TORCH OUT==========") + print(torch_out) + if parser: + torch_out = parser(torch_out) + jtorch_out = parser(jtorch_out) + if checker: + checker(torch_out, jtorch_out) + else: + assert torch_out == jtorch_out + return jtorch_out, torch_out + +jtorch_path = os.path.join(os.path.dirname(__file__), "..") +# come from https://pytorch.org/tutorials/beginner/pytorch_with_examples.html +class TestTutorial(unittest.TestCase): + def test_auto_grad1(self): + check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad1.py", + parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), + checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) + def test_auto_grad2(self): + check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad2.py", + parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), + checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) + def test_auto_grad3(self): + check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad3.py", + parser=lambda s: np.array(s.split())[[-9,-7,-4,-2]].astype(float), + checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) + def test_auto_grad4(self): + check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad4.py", + parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), + checker=lambda a,b: 
np.testing.assert_allclose(a, b, atol=1e-4)) + def test_auto_grad5(self): + check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad5_optim.py", + parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), + checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2)) + def test_auto_grad6(self): + check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad6_module.py", + parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), + checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) + def test_auto_grad7(self): + check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad7_dynet.py", + parser=lambda s: np.array(s.split())[[-13,-10,-7,-3]].astype(float), + checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2)) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad1.py b/python/jittor/compatibility/tutorial/auto_grad1.py new file mode 100644 index 00000000..60a090ad --- /dev/null +++ b/python/jittor/compatibility/tutorial/auto_grad1.py @@ -0,0 +1,44 @@ +import torch +import math + +dtype = torch.float +device = torch.device("cpu") +# device = torch.device("cuda:0") # Uncomment this to run on GPU + +# Create random input and output data +x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) +y = torch.sin(x) + +# Randomly initialize weights +a = torch.randn((), device=device, dtype=dtype) +b = torch.randn((), device=device, dtype=dtype) +c = torch.randn((), device=device, dtype=dtype) +d = torch.randn((), device=device, dtype=dtype) + +learning_rate = 1e-6 +for t in range(20000): + # Forward pass: compute predicted y + y_pred = a + b * x + c * x ** 2 + d * x ** 3 + + # Compute and print loss + loss = (y_pred - y).pow(2).sum().item() + if t % 1000 == 999: + print(t, loss) + + # Backprop to compute gradients of a, b, c, d with respect to loss + grad_y_pred = 2.0 * (y_pred - y) + grad_a = grad_y_pred.sum() + grad_b = (grad_y_pred * x).sum() + grad_c = (grad_y_pred * x ** 2).sum() + grad_d = (grad_y_pred * x ** 3).sum() + + # Update weights using gradient descent + a -= learning_rate * grad_a + b -= learning_rate * grad_b + c -= learning_rate * grad_c + d -= learning_rate * grad_d + # print(t, torch.liveness_info()) + # torch.sync_all() + + +print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad2.py b/python/jittor/compatibility/tutorial/auto_grad2.py new file mode 100644 index 00000000..a3bbc9a8 --- /dev/null +++ b/python/jittor/compatibility/tutorial/auto_grad2.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +import torch +import math + +dtype = torch.float +device = torch.device("cpu") +# device = torch.device("cuda:0") # Uncomment this to run on GPU + +# Create Tensors to hold input and outputs. +# By default, requires_grad=False, which indicates that we do not need to +# compute gradients with respect to these Tensors during the backward pass. +x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) +y = torch.sin(x) + +# Create random Tensors for weights. For a third order polynomial, we need +# 4 weights: y = a + b x + c x^2 + d x^3 +# Setting requires_grad=True indicates that we want to compute gradients with +# respect to these Tensors during the backward pass. 
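+# (each weight below is a 0-dim tensor; requires_grad=True makes autograd record the
+# operations applied to it, so loss.backward() can populate a.grad, b.grad, c.grad, d.grad)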
+a = torch.randn((), device=device, dtype=dtype, requires_grad=True) +b = torch.randn((), device=device, dtype=dtype, requires_grad=True) +c = torch.randn((), device=device, dtype=dtype, requires_grad=True) +d = torch.randn((), device=device, dtype=dtype, requires_grad=True) + +learning_rate = 1e-6 +for t in range(20000): + # Forward pass: compute predicted y using operations on Tensors. + y_pred = a + b * x + c * x ** 2 + d * x ** 3 + # print(y_pred.requires_grad) + # y_pred.requires_grad = False + + # Compute and print loss using operations on Tensors. + # Now loss is a Tensor of shape (1,) + # loss.item() gets the scalar value held in the loss. + loss = (y_pred - y).pow(2).sum() + if t % 1000 == 990: + print(t, loss.item()) + + # Use autograd to compute the backward pass. This call will compute the + # gradient of loss with respect to all Tensors with requires_grad=True. + # After this call a.grad, b.grad. c.grad and d.grad will be Tensors holding + # the gradient of the loss with respect to a, b, c, d respectively. + # torch.backward(loss) + loss.backward() + + # Manually update weights using gradient descent. Wrap in torch.no_grad() + # because weights have requires_grad=True, but we don't need to track this + # in autograd. + with torch.no_grad(): + a -= learning_rate * a.grad + b -= learning_rate * b.grad + c -= learning_rate * c.grad + d -= learning_rate * d.grad + + # Manually zero the gradients after updating weights + a.grad = None + b.grad = None + c.grad = None + d.grad = None + +print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad3.py b/python/jittor/compatibility/tutorial/auto_grad3.py new file mode 100644 index 00000000..654ec447 --- /dev/null +++ b/python/jittor/compatibility/tutorial/auto_grad3.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +import torch +import math + + +class LegendrePolynomial3(torch.autograd.Function): + """ + We can implement our own custom autograd Functions by subclassing + torch.autograd.Function and implementing the forward and backward passes + which operate on Tensors. + """ + + @staticmethod + def forward(ctx, input): + """ + In the forward pass we receive a Tensor containing the input and return + a Tensor containing the output. ctx is a context object that can be used + to stash information for backward computation. You can cache arbitrary + objects for use in the backward pass using the ctx.save_for_backward method. + """ + ctx.save_for_backward(input) + return 0.5 * (5 * input ** 3 - 3 * input) + + @staticmethod + def backward(ctx, grad_output): + """ + In the backward pass we receive a Tensor containing the gradient of the loss + with respect to the output, and we need to compute the gradient of the loss + with respect to the input. + """ + input, = ctx.saved_tensors + return grad_output * 1.5 * (5 * input ** 2 - 1) + + +dtype = torch.float +device = torch.device("cpu") +# device = torch.device("cuda:0") # Uncomment this to run on GPU + +# Create Tensors to hold input and outputs. +# By default, requires_grad=False, which indicates that we do not need to +# compute gradients with respect to these Tensors during the backward pass. +x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) +y = torch.sin(x) + +# Create random Tensors for weights. For this example, we need +# 4 weights: y = a + b * P3(c + d * x), these weights need to be initialized +# not too far from the correct result to ensure convergence. 
+# Setting requires_grad=True indicates that we want to compute gradients with +# respect to these Tensors during the backward pass. +a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True) +b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True) +c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True) +d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True) + +learning_rate = 5e-6 +for t in range(2000): + # To apply our Function, we use Function.apply method. We alias this as 'P3'. + P3 = LegendrePolynomial3.apply + + # Forward pass: compute predicted y using operations; we compute + # P3 using our custom autograd operation. + y_pred = a + b * P3(c + d * x) + + # Compute and print loss + loss = (y_pred - y).pow(2).sum() + if t % 100 == 99: + print(t, loss.item()) + + # Use autograd to compute the backward pass. + loss.backward() + + # Update weights using gradient descent + with torch.no_grad(): + a -= learning_rate * a.grad + b -= learning_rate * b.grad + c -= learning_rate * c.grad + d -= learning_rate * d.grad + + # Manually zero the gradients after updating weights + a.grad = None + b.grad = None + c.grad = None + d.grad = None + +print(f'Result: y = {a.item()} + {b.item()} * P3( {c.item()} + {d.item()} x)') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad4.py b/python/jittor/compatibility/tutorial/auto_grad4.py new file mode 100644 index 00000000..062d0b0e --- /dev/null +++ b/python/jittor/compatibility/tutorial/auto_grad4.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +import torch +import math + + +# Create Tensors to hold input and outputs. +x = torch.linspace(-math.pi, math.pi, 2000) +y = torch.sin(x) + +# For this example, the output y is a linear function of (x, x^2, x^3), so +# we can consider it as a linear layer neural network. Let's prepare the +# tensor (x, x^2, x^3). +p = torch.tensor([1, 2, 3]) +xx = x.unsqueeze(-1).pow(p) + +# In the above code, x.unsqueeze(-1) has shape (2000, 1), and p has shape +# (3,), for this case, broadcasting semantics will apply to obtain a tensor +# of shape (2000, 3) + +# Use the nn package to define our model as a sequence of layers. nn.Sequential +# is a Module which contains other Modules, and applies them in sequence to +# produce its output. The Linear Module computes output from input using a +# linear function, and holds internal Tensors for its weight and bias. +# The Flatten layer flatens the output of the linear layer to a 1D tensor, +# to match the shape of `y`. +model = torch.nn.Sequential( + torch.nn.Linear(3, 1), + torch.nn.Flatten(0, 1) +) + +# The nn package also contains definitions of popular loss functions; in this +# case we will use Mean Squared Error (MSE) as our loss function. +loss_fn = torch.nn.MSELoss(reduction='sum') +# print(model[0].weight.requires_grad) + +learning_rate = 1e-6 +for t in range(8000): + + # Forward pass: compute predicted y by passing x to the model. Module objects + # override the __call__ operator so you can call them like functions. When + # doing so you pass a Tensor of input data to the Module and it produces + # a Tensor of output data. + y_pred = model(xx) + + # Compute and print loss. We pass Tensors containing the predicted and true + # values of y, and the loss function returns a Tensor containing the + # loss. + loss = loss_fn(y_pred, y) + if t % 1000 == 999: + print(t, loss.item()) + + # Zero the gradients before running the backward pass. 
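+    # (gradients accumulate into each parameter's .grad across backward() calls,
+    # so they are cleared once per iteration before the new backward pass)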
+ model.zero_grad() + + # Backward pass: compute gradient of the loss with respect to all the learnable + # parameters of the model. Internally, the parameters of each Module are stored + # in Tensors with requires_grad=True, so this call will compute gradients for + # all learnable parameters in the model. + loss.backward() + + # Update the weights using gradient descent. Each parameter is a Tensor, so + # we can access its gradients like we did before. + with torch.no_grad(): + for param in model.parameters(): + param -= learning_rate * param.grad + +# You can access the first layer of `model` like accessing the first item of a list +linear_layer = model[0] + +# For linear layer, its parameters are stored as `weight` and `bias`. +print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad5_optim.py b/python/jittor/compatibility/tutorial/auto_grad5_optim.py new file mode 100644 index 00000000..04949320 --- /dev/null +++ b/python/jittor/compatibility/tutorial/auto_grad5_optim.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +import torch +import math + + +# Create Tensors to hold input and outputs. +x = torch.linspace(-math.pi, math.pi, 2000) +y = torch.sin(x) + +# Prepare the input tensor (x, x^2, x^3). +p = torch.tensor([1, 2, 3]) +xx = x.unsqueeze(-1).pow(p) + +# Use the nn package to define our model and loss function. +model = torch.nn.Sequential( + torch.nn.Linear(3, 1), + torch.nn.Flatten(0, 1) +) +loss_fn = torch.nn.MSELoss(reduction='sum') + +# Use the optim package to define an Optimizer that will update the weights of +# the model for us. Here we will use RMSprop; the optim package contains many other +# optimization algorithms. The first argument to the RMSprop constructor tells the +# optimizer which Tensors it should update. +learning_rate = 1e-3 +optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate) +for t in range(8000): + # Forward pass: compute predicted y by passing x to the model. + y_pred = model(xx) + + # Compute and print loss. + loss = loss_fn(y_pred, y) + if t % 1000 == 999: + print(t, loss.item()) + + # Before the backward pass, use the optimizer object to zero all of the + # gradients for the variables it will update (which are the learnable + # weights of the model). This is because by default, gradients are + # accumulated in buffers( i.e, not overwritten) whenever .backward() + # is called. Checkout docs of torch.autograd.backward for more details. 
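+    # (this replaces the manual `x.grad = None` zeroing used in the earlier examples;
+    # the optimizer clears the gradient of every parameter it manages)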
+ optimizer.zero_grad() + + # Backward pass: compute gradient of the loss with respect to model + # parameters + loss.backward() + + # Calling the step function on an Optimizer makes an update to its + # parameters + optimizer.step() + + +linear_layer = model[0] +print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad6_module.py b/python/jittor/compatibility/tutorial/auto_grad6_module.py new file mode 100644 index 00000000..a240e2b5 --- /dev/null +++ b/python/jittor/compatibility/tutorial/auto_grad6_module.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +import torch +import math + + +class Polynomial3(torch.nn.Module): + def __init__(self): + """ + In the constructor we instantiate four parameters and assign them as + member parameters. + """ + super().__init__() + self.a = torch.nn.Parameter(torch.randn(())) + self.b = torch.nn.Parameter(torch.randn(())) + self.c = torch.nn.Parameter(torch.randn(())) + self.d = torch.nn.Parameter(torch.randn(())) + + def forward(self, x): + """ + In the forward function we accept a Tensor of input data and we must return + a Tensor of output data. We can use Modules defined in the constructor as + well as arbitrary operators on Tensors. + """ + return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3 + + def string(self): + """ + Just like any class in Python, you can also define custom method on PyTorch modules + """ + return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3' + + +# Create Tensors to hold input and outputs. +x = torch.linspace(-math.pi, math.pi, 2000) +y = torch.sin(x) + +# Construct our model by instantiating the class defined above +model = Polynomial3() + +# Construct our loss function and an Optimizer. The call to model.parameters() +# in the SGD constructor will contain the learnable parameters (defined +# with torch.nn.Parameter) which are members of the model. +criterion = torch.nn.MSELoss(reduction='sum') +optimizer = torch.optim.SGD(model.parameters(), lr=1e-6) +for t in range(8000): + # Forward pass: Compute predicted y by passing x to the model + y_pred = model(x) + + # Compute and print loss + loss = criterion(y_pred, y) + if t % 1000 == 999: + print(t, loss.item()) + + # Zero gradients, perform a backward pass, and update the weights. + optimizer.zero_grad() + loss.backward() + optimizer.step() + +print(f'Result: {model.string()}') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad7_dynet.py b/python/jittor/compatibility/tutorial/auto_grad7_dynet.py new file mode 100644 index 00000000..fa954771 --- /dev/null +++ b/python/jittor/compatibility/tutorial/auto_grad7_dynet.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +import random +import torch +import math + + +class DynamicNet(torch.nn.Module): + def __init__(self): + """ + In the constructor we instantiate five parameters and assign them as members. + """ + super().__init__() + self.a = torch.nn.Parameter(torch.randn(())) + self.b = torch.nn.Parameter(torch.randn(())) + self.c = torch.nn.Parameter(torch.randn(())) + self.d = torch.nn.Parameter(torch.randn(())) + self.e = torch.nn.Parameter(torch.randn(())) + + def forward(self, x): + """ + For the forward pass of the model, we randomly choose either 4, 5 + and reuse the e parameter to compute the contribution of these orders. 
+ + Since each forward pass builds a dynamic computation graph, we can use normal + Python control-flow operators like loops or conditional statements when + defining the forward pass of the model. + + Here we also see that it is perfectly safe to reuse the same parameter many + times when defining a computational graph. + """ + y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3 + for exp in range(4, random.randint(4, 6)): + y = y + self.e * x ** exp + return y + + def string(self): + """ + Just like any class in Python, you can also define custom method on PyTorch modules + """ + return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?' + + +# Create Tensors to hold input and outputs. +x = torch.linspace(-math.pi, math.pi, 2000) +y = torch.sin(x) + +# Construct our model by instantiating the class defined above +model = DynamicNet() + +# Construct our loss function and an Optimizer. Training this strange model with +# vanilla stochastic gradient descent is tough, so we use momentum +criterion = torch.nn.MSELoss(reduction='sum') +optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9) +for t in range(60000): + # Forward pass: Compute predicted y by passing x to the model + y_pred = model(x) + + # Compute and print loss + loss = criterion(y_pred, y) + if t % 2000 == 1999: + print(t, loss.item()) + + # Zero gradients, perform a backward pass, and update the weights. + optimizer.zero_grad() + loss.backward() + optimizer.step() + # print(torch.liveness_info()) + +print(f'Result: {model.string()}') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/quickstart.py b/python/jittor/compatibility/tutorial/quickstart.py new file mode 100644 index 00000000..f0401a9b --- /dev/null +++ b/python/jittor/compatibility/tutorial/quickstart.py @@ -0,0 +1,106 @@ +import torch +from torch import nn +# from jtorch.utils import DataLoader +from torch.utils.data import DataLoader +from torchvision import datasets +from torchvision.transforms import ToTensor + +# Download training data from open datasets. +training_data = datasets.FashionMNIST( + root="data", + train=True, + download=True, + transform=ToTensor(), +) + +# Download test data from open datasets. +test_data = datasets.FashionMNIST( + root="data", + train=False, + download=True, + transform=ToTensor(), +) + +batch_size = 64 + +# Create data loaders. +train_dataloader = DataLoader(training_data, batch_size=batch_size) +test_dataloader = DataLoader(test_data, batch_size=batch_size) + +print(len(train_dataloader)) +for X, y in test_dataloader: + print(f"Shape of X [N, C, H, W]: {X.shape}") + print(f"Shape of y: {y.shape} {y.dtype}") + break + +# Get cpu or gpu device for training. 
+device = "cuda" if torch.cuda.is_available() else "cpu" +print(f"Using {device} device") + +# Define model +class NeuralNetwork(nn.Module): + def __init__(self): + super(NeuralNetwork, self).__init__() + self.flatten = nn.Flatten() + self.linear_relu_stack = nn.Sequential( + nn.Linear(28*28, 512), + nn.ReLU(), + nn.Linear(512, 512), + nn.ReLU(), + nn.Linear(512, 10) + ) + + def forward(self, x): + x = self.flatten(x) + logits = self.linear_relu_stack(x) + return logits + +model = NeuralNetwork().to(device) +print(model) + + +loss_fn = nn.CrossEntropyLoss() +optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + +def train(dataloader, model, loss_fn, optimizer): + size = len(dataloader.dataset) + model.train() + for batch, (X, y) in enumerate(dataloader): + X, y = X.to(device), y.to(device) + + # Compute prediction error + pred = model(X) + loss = loss_fn(pred, y) + + # Backpropagation + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if batch % 100 == 0: + loss, current = loss.item(), batch * len(X) + print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]") + +def test(dataloader, model, loss_fn): + size = len(dataloader.dataset) + num_batches = len(dataloader) + model.eval() + test_loss, correct = 0, 0 + with torch.no_grad(): + for X, y in dataloader: + X, y = X.to(device), y.to(device) + pred = model(X) + test_loss += loss_fn(pred, y).item() + correct += (pred.argmax(1) == y).type(torch.float).sum().item() + test_loss /= num_batches + correct /= size + print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n") + + +epochs = 5 +test(test_dataloader, model, loss_fn) +for t in range(epochs): + print(f"Epoch {t+1}\n-------------------------------") + train(train_dataloader, model, loss_fn, optimizer) + test(test_dataloader, model, loss_fn) +print("Done!") \ No newline at end of file diff --git a/python/jittor/compatibility/utils/__init__.py b/python/jittor/compatibility/utils/__init__.py new file mode 100644 index 00000000..ac2c2bd8 --- /dev/null +++ b/python/jittor/compatibility/utils/__init__.py @@ -0,0 +1,5 @@ +cpp_extension = None +_flatten_dense_tensors = None +_unflatten_dense_tensors = None + +tensorboard = None \ No newline at end of file diff --git a/python/jittor/compatibility/utils/_pytree.py b/python/jittor/compatibility/utils/_pytree.py new file mode 100644 index 00000000..c3118964 --- /dev/null +++ b/python/jittor/compatibility/utils/_pytree.py @@ -0,0 +1,3 @@ +#TODO: Implement this +_register_pytree_node = None +_dict_flatten = None \ No newline at end of file diff --git a/python/jittor/compatibility/utils/checkpoint.py b/python/jittor/compatibility/utils/checkpoint.py new file mode 100644 index 00000000..ba3c3e8e --- /dev/null +++ b/python/jittor/compatibility/utils/checkpoint.py @@ -0,0 +1,8 @@ +detach_variable = None + + +def checkpoint( + *args, + **kwargs +): + pass diff --git a/python/jittor/compatibility/utils/data.py b/python/jittor/compatibility/utils/data.py new file mode 100644 index 00000000..5fcfcaa6 --- /dev/null +++ b/python/jittor/compatibility/utils/data.py @@ -0,0 +1,137 @@ +import jittor as jt +import jittor.dataset +from jittor.dataset import Dataset as JDataset + +from collections import namedtuple +from typing import Any, Callable, Iterable, Optional, Sequence, Union + + +class Dataset: + def __getitem__(self, index): + raise NotImplementedError + +class IterableDataset: + def __iter__(self): + raise NotImplementedError + + +class DataLoader(JDataset): + def __init__(self, dataset, + batch_size: Optional[int] 
= 1, + shuffle: Optional[bool] = False, + sampler = None, + batch_sampler = None, + num_workers: int = 0, + collate_fn = None, + pin_memory: bool = False, + drop_last: bool = False, + timeout: float = 0, + worker_init_fn = None, + multiprocessing_context=None, + generator=None, + *, prefetch_factor: int = 2, + persistent_workers: bool = False, + pin_memory_device: str = "") -> None: + super().__init__(batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + drop_last=drop_last) + + unsupported_kwargs = { + "batch_sampler": batch_sampler, + "pin_memory": pin_memory, + "timeout": timeout, + "worker_init_fn": worker_init_fn, + "multiprocessing_context": multiprocessing_context, + "generator": generator, + "persistent_workers": persistent_workers, + "pin_memory_device": pin_memory_device + } + for kwarg, value in unsupported_kwargs.items(): + if value: + jt.LOG.w(f"Not implemented Dataloader kwarg: {kwarg}") + + self.dataset = dataset + self.collate_fn = collate_fn + self.sampler = sampler + + if not isinstance(dataset, IterableDataset): + self.total_len = len(dataset) + else: + # TODO: support multiple worker for iterable dataset + assert(num_workers == 0) + + def collate_batch(self, batch): + if self.collate_fn is not None: + return self.collate_fn(batch) + else: + return super().collate_batch(batch) + + def __getitem__(self, i): + return self.dataset[i] + + def __iter__(self): + if isinstance(self.dataset, IterableDataset): + return self.inner_iter() + else: + return super().__iter__() + + def inner_iter(self): + current_batch = [] + + if jt.world_size > 1: + assert self.batch_size % jt.world_size == 0, \ + f"IterableDataset does not support a batch size ({self.batch_size}) that is not evenly divisible by the number of processes f{jt.world_size}" + real_batch_size = int(self.batch_size / jt.world_size) + else: + real_batch_size = self.batch_size + + for element in self.dataset: + current_batch.append(element) + + if len(current_batch) == real_batch_size: + current_batch = self.collate_batch(current_batch) + current_batch = self.to_jittor(current_batch) + yield current_batch + current_batch = [] + + if not self.drop_last and len(current_batch) > 0: + current_batch = self.collate_batch(current_batch) + yield self.to_jittor(current_batch) + +# def get_worker_info(): +# # always return the fake worker info +# return namedtuple('WorkerInfo', 'id num_workers')(0, 1) + +# class RandomSampler(jt.dataset.RandomSampler): +# def __init__(self, dataset, generator=None, **kwargs): +# super().__init__(dataset, **kwargs) + +# def __iter__(self): +# if getattr(self.dataset, "support_random_access", True): +# return super().__iter__() +# else: +# self.dataset.shuffle() +# return iter(range(self.dataset.__real_len__() if hasattr(self.dataset,"__real_len__") else self.dataset.__len__())) + +# class DistributedSampler(jt.dataset.Sampler): +# def __init__(self, sampler: RandomSampler): +# assert(isinstance(sampler, RandomSampler)) +# self.sampler = sampler + +# def set_epoch(self, epoch: int): +# ### do nothing, let jittor's inner dataset handle +# pass + +# def __iter__(self): +# return self.sampler.__iter__() + +# def __len__(self): +# return self.sampler.__len__() + +# BatchSampler = jt.dataset.BatchSampler +# Sampler = jt.dataset.Sampler +# SequentialSampler = jt.dataset.SequentialSampler +# SubsetRandomSampler = jt.dataset.SubsetRandomSampler + +# TensorDataset = Dataset diff --git a/python/jittor/compatibility/utils/dtype.py b/python/jittor/compatibility/utils/dtype.py new file mode 100644 
index 00000000..41728383 --- /dev/null +++ b/python/jittor/compatibility/utils/dtype.py @@ -0,0 +1,9 @@ +from typing import Callable, Union +Dtype = Union[Callable, str] + +def get_string_dtype(dtype): + if callable(dtype): + dtype = dtype.__name__ + if not isinstance(dtype, str): + raise ValueError(f"dtype is expected to be str, python type function, or jittor type function, but got {dtype}.") + return dtype \ No newline at end of file diff --git a/python/jittor/compatibility/utils/hooks.py b/python/jittor/compatibility/utils/hooks.py new file mode 100644 index 00000000..e69de29b diff --git a/python/jittor/compatibility/utils/pip_publish.py b/python/jittor/compatibility/utils/pip_publish.py new file mode 100644 index 00000000..72ff245f --- /dev/null +++ b/python/jittor/compatibility/utils/pip_publish.py @@ -0,0 +1,34 @@ +import os +import glob +import shutil +import sys + +home_path = os.path.join(os.path.dirname(__file__), "..", "..", "..") +home_path = os.path.abspath(home_path) + +def callback(func, path, exc_info): + print(f"remove \"{path}\" failed.") + +def rmtree(path): + if os.path.isdir(path): + print(f"remove \"{path}\" recursive.") + shutil.rmtree(path, onerror=callback) + +def remove_tmpfile(): + dist_file = home_path+"/dist" + egg_file = glob.glob(home_path+"/**/*egg-info") + rmtree(dist_file) + for e in egg_file: + rmtree(e) + +def run_cmd(cmd): + print("[CMD]", cmd) + assert os.system(cmd)==0 + +os.chdir(home_path) +remove_tmpfile() + +run_cmd(f"{sys.executable} ./setup.py sdist") +run_cmd(f"{sys.executable} -m twine upload dist/*") + +remove_tmpfile() \ No newline at end of file diff --git a/python/jittor/compatibility/vision/_internally_replaced_utils.py b/python/jittor/compatibility/vision/_internally_replaced_utils.py new file mode 100644 index 00000000..748fa2ea --- /dev/null +++ b/python/jittor/compatibility/vision/_internally_replaced_utils.py @@ -0,0 +1,46 @@ +import importlib.machinery +import os + + +def _download_file_from_remote_location(fpath: str, url: str) -> None: + pass + + +def _is_remote_location_available() -> bool: + return False + + +def _get_extension_path(lib_name): + + lib_dir = os.path.dirname(__file__) + if os.name == "nt": + # Register the main torchvision library location on the default DLL path + import ctypes + import sys + + kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) + with_load_library_flags = hasattr(kernel32, "AddDllDirectory") + prev_error_mode = kernel32.SetErrorMode(0x0001) + + if with_load_library_flags: + kernel32.AddDllDirectory.restype = ctypes.c_void_p + + if sys.version_info >= (3, 8): + os.add_dll_directory(lib_dir) + elif with_load_library_flags: + res = kernel32.AddDllDirectory(lib_dir) + if res is None: + err = ctypes.WinError(ctypes.get_last_error()) + err.strerror += f' Error adding "{lib_dir}" to the DLL directories.' 
+ raise err + + kernel32.SetErrorMode(prev_error_mode) + + loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES) + + extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) + ext_specs = extfinder.find_spec(lib_name) + if ext_specs is None: + raise ImportError + + return ext_specs.origin \ No newline at end of file diff --git a/python/jittor/compatibility/vision/datasets/__init__.py b/python/jittor/compatibility/vision/datasets/__init__.py new file mode 100644 index 00000000..d04187f1 --- /dev/null +++ b/python/jittor/compatibility/vision/datasets/__init__.py @@ -0,0 +1,9 @@ +from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST + +__all__ = ( + "EMNIST", + "FashionMNIST", + "QMNIST", + "MNIST", + "KMNIST", +) \ No newline at end of file diff --git a/python/jittor/compatibility/vision/datasets/mnist.py b/python/jittor/compatibility/vision/datasets/mnist.py new file mode 100644 index 00000000..dfc3787b --- /dev/null +++ b/python/jittor/compatibility/vision/datasets/mnist.py @@ -0,0 +1,558 @@ +import codecs +import os +import os.path +import shutil +import string +import sys +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple +from urllib.error import URLError + +import numpy as np +import torch +from PIL import Image + +from .utils import check_integrity, download_and_extract_archive, extract_archive, verify_str_arg +from .vision import VisionDataset + + +class MNIST(VisionDataset): + """`MNIST `_ Dataset. + + Args: + root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte`` + and ``MNIST/raw/t10k-images-idx3-ubyte`` exist. + train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``, + otherwise from ``t10k-images-idx3-ubyte``. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. 
+ """ + + mirrors = [ + "http://yann.lecun.com/exdb/mnist/", + "https://ossci-datasets.s3.amazonaws.com/mnist/", + ] + + resources = [ + ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), + ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"), + ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"), + ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"), + ] + + training_file = "training.pt" + test_file = "test.pt" + classes = [ + "0 - zero", + "1 - one", + "2 - two", + "3 - three", + "4 - four", + "5 - five", + "6 - six", + "7 - seven", + "8 - eight", + "9 - nine", + ] + + @property + def train_labels(self): + warnings.warn("train_labels has been renamed targets") + return self.targets + + @property + def test_labels(self): + warnings.warn("test_labels has been renamed targets") + return self.targets + + @property + def train_data(self): + warnings.warn("train_data has been renamed data") + return self.data + + @property + def test_data(self): + warnings.warn("test_data has been renamed data") + return self.data + + def __init__( + self, + root: str, + train: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + super().__init__(root, transform=transform, target_transform=target_transform) + self.train = train # training set or test set + + if self._check_legacy_exist(): + self.data, self.targets = self._load_legacy_data() + return + + if download: + self.download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found. You can use download=True to download it") + + self.data, self.targets = self._load_data() + + def _check_legacy_exist(self): + processed_folder_exists = os.path.exists(self.processed_folder) + if not processed_folder_exists: + return False + + return all( + check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file) + ) + + def _load_legacy_data(self): + # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data + # directly. + data_file = self.training_file if self.train else self.test_file + return torch.load(os.path.join(self.processed_folder, data_file)) + + def _load_data(self): + image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte" + data = read_image_file(os.path.join(self.raw_folder, image_file)) + + label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte" + targets = read_label_file(os.path.join(self.raw_folder, label_file)) + + return data, targets + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is index of the target class. 
+ """ + img, target = self.data[index], int(self.targets[index]) + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + img = Image.fromarray(img.numpy(), mode="L") + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self) -> int: + return len(self.data) + + @property + def raw_folder(self) -> str: + return os.path.join(self.root, self.__class__.__name__, "raw") + + @property + def processed_folder(self) -> str: + return os.path.join(self.root, self.__class__.__name__, "processed") + + @property + def class_to_idx(self) -> Dict[str, int]: + return {_class: i for i, _class in enumerate(self.classes)} + + def _check_exists(self) -> bool: + return all( + check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])) + for url, _ in self.resources + ) + + def download(self) -> None: + """Download the MNIST data if it doesn't exist already.""" + + if self._check_exists(): + return + + os.makedirs(self.raw_folder, exist_ok=True) + + # download files + for filename, md5 in self.resources: + for mirror in self.mirrors: + url = f"{mirror}{filename}" + try: + print(f"Downloading {url}") + download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5) + except URLError as error: + print(f"Failed to download (trying next):\n{error}") + continue + finally: + print() + break + else: + raise RuntimeError(f"Error downloading {filename}") + + def extra_repr(self) -> str: + split = "Train" if self.train is True else "Test" + return f"Split: {split}" + + +class FashionMNIST(MNIST): + """`Fashion-MNIST `_ Dataset. + + Args: + root (string): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte`` + and ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist. + train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``, + otherwise from ``t10k-images-idx3-ubyte``. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"] + + resources = [ + ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"), + ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"), + ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"), + ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"), + ] + classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"] + + +class KMNIST(MNIST): + """`Kuzushiji-MNIST `_ Dataset. + + Args: + root (string): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte`` + and ``KMNIST/raw/t10k-images-idx3-ubyte`` exist. + train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``, + otherwise from ``t10k-images-idx3-ubyte``. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. 
+ transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"] + + resources = [ + ("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"), + ("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"), + ("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"), + ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"), + ] + classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"] + + +class EMNIST(MNIST): + """`EMNIST `_ Dataset. + + Args: + root (string): Root directory of dataset where ``EMNIST/raw/train-images-idx3-ubyte`` + and ``EMNIST/raw/t10k-images-idx3-ubyte`` exist. + split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``, + ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies + which one to use. + train (bool, optional): If True, creates dataset from ``training.pt``, + otherwise from ``test.pt``. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + url = "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip" + md5 = "58c8d27c78d21e728a6bc7b3cc06412e" + splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist") + # Merged Classes assumes Same structure for both uppercase and lowercase version + _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"} + _all_classes = set(string.digits + string.ascii_letters) + classes_split_dict = { + "byclass": sorted(list(_all_classes)), + "bymerge": sorted(list(_all_classes - _merged_classes)), + "balanced": sorted(list(_all_classes - _merged_classes)), + "letters": ["N/A"] + list(string.ascii_lowercase), + "digits": list(string.digits), + "mnist": list(string.digits), + } + + def __init__(self, root: str, split: str, **kwargs: Any) -> None: + self.split = verify_str_arg(split, "split", self.splits) + self.training_file = self._training_file(split) + self.test_file = self._test_file(split) + super().__init__(root, **kwargs) + self.classes = self.classes_split_dict[self.split] + + @staticmethod + def _training_file(split) -> str: + return f"training_{split}.pt" + + @staticmethod + def _test_file(split) -> str: + return f"test_{split}.pt" + + @property + def _file_prefix(self) -> str: + return f"emnist-{self.split}-{'train' if self.train else 'test'}" + + @property + def images_file(self) -> str: + return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte") + + @property + def labels_file(self) -> str: + return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte") + + def _load_data(self): + return read_image_file(self.images_file), read_label_file(self.labels_file) + + def _check_exists(self) -> bool: + return all(check_integrity(file) for file in (self.images_file, self.labels_file)) + + def download(self) -> None: + """Download the EMNIST data if it doesn't exist already.""" + + if self._check_exists(): + 
return + + os.makedirs(self.raw_folder, exist_ok=True) + + download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5) + gzip_folder = os.path.join(self.raw_folder, "gzip") + for gzip_file in os.listdir(gzip_folder): + if gzip_file.endswith(".gz"): + extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder) + shutil.rmtree(gzip_folder) + + +class QMNIST(MNIST): + """`QMNIST `_ Dataset. + + Args: + root (string): Root directory of dataset whose ``raw`` + subdir contains binary files of the datasets. + what (string,optional): Can be 'train', 'test', 'test10k', + 'test50k', or 'nist' for respectively the mnist compatible + training set, the 60k qmnist testing set, the 10k qmnist + examples that match the mnist testing set, the 50k + remaining qmnist testing examples, or all the nist + digits. The default is to select 'train' or 'test' + according to the compatibility argument 'train'. + compat (bool,optional): A boolean that says whether the target + for each example is class number (for compatibility with + the MNIST dataloader) or a torch vector containing the + full qmnist information. Default=True. + download (bool, optional): If True, downloads the dataset from + the internet and puts it in root directory. If dataset is + already downloaded, it is not downloaded again. + transform (callable, optional): A function/transform that + takes in an PIL image and returns a transformed + version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform + that takes in the target and transforms it. + train (bool,optional,compatibility): When argument 'what' is + not specified, this boolean decides whether to load the + training set ot the testing set. Default: True. + """ + + subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"} + resources: Dict[str, List[Tuple[str, str]]] = { # type: ignore[assignment] + "train": [ + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz", + "ed72d4157d28c017586c42bc6afe6370", + ), + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz", + "0058f8dd561b90ffdd0f734c6a30e5e4", + ), + ], + "test": [ + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz", + "1394631089c404de565df7b7aeaf9412", + ), + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz", + "5b5b05890a5e13444e108efe57b788aa", + ), + ], + "nist": [ + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz", + "7f124b3b8ab81486c9d8c2749c17f834", + ), + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz", + "5ed0e788978e45d4a8bd4b7caec3d79d", + ), + ], + } + classes = [ + "0 - zero", + "1 - one", + "2 - two", + "3 - three", + "4 - four", + "5 - five", + "6 - six", + "7 - seven", + "8 - eight", + "9 - nine", + ] + + def __init__( + self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any + ) -> None: + if what is None: + what = "train" if train else "test" + self.what = verify_str_arg(what, "what", tuple(self.subsets.keys())) + self.compat = compat + self.data_file = what + ".pt" + self.training_file = self.data_file + self.test_file = self.data_file + super().__init__(root, train, **kwargs) + + @property + def images_file(self) -> str: + (url, _), _ = 
self.resources[self.subsets[self.what]] + return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]) + + @property + def labels_file(self) -> str: + _, (url, _) = self.resources[self.subsets[self.what]] + return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]) + + def _check_exists(self) -> bool: + return all(check_integrity(file) for file in (self.images_file, self.labels_file)) + + def _load_data(self): + data = read_sn3_pascalvincent_tensor(self.images_file) + if data.dtype != torch.uint8: + raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}") + if data.ndimension() != 3: + raise ValueError("data should have 3 dimensions instead of {data.ndimension()}") + + targets = read_sn3_pascalvincent_tensor(self.labels_file).long() + if targets.ndimension() != 2: + raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}") + + if self.what == "test10k": + data = data[0:10000, :, :].clone() + targets = targets[0:10000, :].clone() + elif self.what == "test50k": + data = data[10000:, :, :].clone() + targets = targets[10000:, :].clone() + + return data, targets + + def download(self) -> None: + """Download the QMNIST data if it doesn't exist already. + Note that we only download what has been asked for (argument 'what'). + """ + if self._check_exists(): + return + + os.makedirs(self.raw_folder, exist_ok=True) + split = self.resources[self.subsets[self.what]] + + for url, md5 in split: + download_and_extract_archive(url, self.raw_folder, md5=md5) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + # redefined to handle the compat flag + img, target = self.data[index], self.targets[index] + img = Image.fromarray(img.numpy(), mode="L") + if self.transform is not None: + img = self.transform(img) + if self.compat: + target = int(target[0]) + if self.target_transform is not None: + target = self.target_transform(target) + return img, target + + def extra_repr(self) -> str: + return f"Split: {self.what}" + + +def get_int(b: bytes) -> int: + return int(codecs.encode(b, "hex"), 16) + + +SN3_PASCALVINCENT_BITSMAP = { + 8: torch.uint8, + 9: torch.int8, + 11: torch.int16, + 12: torch.int32, + 13: torch.float32, + 14: torch.float64, +} + +TORCH_TYPE_BITS = { + torch.uint8: 8, + torch.int8: 8, + torch.int16: 16, + torch.int32: 32, + torch.float32: 32, + torch.float64: 64, +} + + +def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor: + """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). + Argument may be a filename, compressed filename, or file object. + """ + # read + with open(path, "rb") as f: + data = f.read() + # parse + magic = get_int(data[0:4]) + nd = magic % 256 + ty = magic // 256 + assert 1 <= nd <= 3 + assert 8 <= ty <= 14 + torch_type = SN3_PASCALVINCENT_BITSMAP[ty] + s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)] + + num_bytes_per_value = TORCH_TYPE_BITS[torch_type] // 8 + # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default, + # we need to reverse the bytes before we can read them with torch.frombuffer(). 
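# (Editor's note, not part of the patch.) Worked example of the header parsed in this function:
# for MNIST's train-images-idx3-ubyte the first four bytes are 0x00000803, so
# magic = 2051, ty = 2051 // 256 = 8 (mapped to torch.uint8) and nd = 2051 % 256 = 3;
# the next three big-endian int32 values are the dimensions 60000, 28 and 28.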
+ needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1 + parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1))) + if needs_byte_reversal: + parsed = parsed.flip(0) + + assert parsed.shape[0] == np.prod(s) or not strict + return parsed.view(*s) + + +def read_label_file(path: str) -> torch.Tensor: + x = read_sn3_pascalvincent_tensor(path, strict=False) + if x.dtype != torch.uint8: + raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}") + if x.ndimension() != 1: + raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}") + return x.long() + + +def read_image_file(path: str) -> torch.Tensor: + x = read_sn3_pascalvincent_tensor(path, strict=False) + if x.dtype != torch.uint8: + raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}") + if x.ndimension() != 3: + raise ValueError(f"x should have 3 dimension instead of {x.ndimension()}") + return x \ No newline at end of file diff --git a/python/jittor/compatibility/vision/datasets/utils.py b/python/jittor/compatibility/vision/datasets/utils.py new file mode 100644 index 00000000..f9ae1a89 --- /dev/null +++ b/python/jittor/compatibility/vision/datasets/utils.py @@ -0,0 +1,522 @@ +import bz2 +import contextlib +import gzip +import hashlib +import itertools +import lzma +import os +import os.path +import pathlib +import re +import sys +import tarfile +import urllib +import urllib.error +import urllib.request +import warnings +import zipfile +from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar +from urllib.parse import urlparse + +import numpy as np +import requests +import torch +from tqdm import tqdm + +from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available + +USER_AGENT = "pytorch/vision" + + +def _save_response_content( + content: Iterator[bytes], + destination: str, + length: Optional[int] = None, +) -> None: + with open(destination, "wb") as fh, tqdm(total=length) as pbar: + for chunk in content: + # filter out keep-alive new chunks + if not chunk: + continue + + fh.write(chunk) + pbar.update(len(chunk)) + + +def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None: + with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response: + _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length) + + +def gen_bar_updater() -> Callable[[int, int, int], None]: + warnings.warn("The function `gen_bar_update` is deprecated since 0.13 and will be removed in 0.15.") + pbar = tqdm(total=None) + + def bar_update(count, block_size, total_size): + if pbar.total is None and total_size: + pbar.total = total_size + progress_bytes = count * block_size + pbar.update(progress_bytes - pbar.n) + + return bar_update + + +def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str: + # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are + # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without + # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere. 
+ if sys.version_info >= (3, 9): + md5 = hashlib.md5(usedforsecurity=False) + else: + md5 = hashlib.md5() + with open(fpath, "rb") as f: + for chunk in iter(lambda: f.read(chunk_size), b""): + md5.update(chunk) + return md5.hexdigest() + + +def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool: + return md5 == calculate_md5(fpath, **kwargs) + + +def check_integrity(fpath: str, md5: Optional[str] = None) -> bool: + if not os.path.isfile(fpath): + return False + if md5 is None: + return True + return check_md5(fpath, md5) + + +def _get_redirect_url(url: str, max_hops: int = 3) -> str: + initial_url = url + headers = {"Method": "HEAD", "User-Agent": USER_AGENT} + + for _ in range(max_hops + 1): + with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response: + if response.url == url or response.url is None: + return url + + url = response.url + else: + raise RecursionError( + f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}." + ) + + +def _get_google_drive_file_id(url: str) -> Optional[str]: + parts = urlparse(url) + + if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None: + return None + + match = re.match(r"/file/d/(?P[^/]*)", parts.path) + if match is None: + return None + + return match.group("id") + + +def download_url( + url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3 +) -> None: + """Download a file from a url and place it in root. + + Args: + url (str): URL to download file from + root (str): Directory to place downloaded file in + filename (str, optional): Name to save the file under. If None, use the basename of the URL + md5 (str, optional): MD5 checksum of the download. If None, do not check + max_redirect_hops (int, optional): Maximum number of redirect hops allowed + """ + root = os.path.expanduser(root) + if not filename: + filename = os.path.basename(url) + fpath = os.path.join(root, filename) + + os.makedirs(root, exist_ok=True) + + # check if file is already present locally + if check_integrity(fpath, md5): + print("Using downloaded and verified file: " + fpath) + return + + if _is_remote_location_available(): + _download_file_from_remote_location(fpath, url) + else: + # expand redirect chain if needed + url = _get_redirect_url(url, max_hops=max_redirect_hops) + + # check if file is located on Google Drive + file_id = _get_google_drive_file_id(url) + if file_id is not None: + return download_file_from_google_drive(file_id, root, filename, md5) + + # download the file + try: + print("Downloading " + url + " to " + fpath) + _urlretrieve(url, fpath) + except (urllib.error.URLError, OSError) as e: # type: ignore[attr-defined] + if url[:5] == "https": + url = url.replace("https:", "http:") + print("Failed download. Trying https -> http instead. 
Downloading " + url + " to " + fpath) + _urlretrieve(url, fpath) + else: + raise e + + # check integrity of downloaded file + if not check_integrity(fpath, md5): + raise RuntimeError("File not found or corrupted.") + + +def list_dir(root: str, prefix: bool = False) -> List[str]: + """List all directories at a given root + + Args: + root (str): Path to directory whose folders need to be listed + prefix (bool, optional): If true, prepends the path to each result, otherwise + only returns the name of the directories found + """ + root = os.path.expanduser(root) + directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))] + if prefix is True: + directories = [os.path.join(root, d) for d in directories] + return directories + + +def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]: + """List all files ending with a suffix at a given root + + Args: + root (str): Path to directory whose folders need to be listed + suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). + It uses the Python "str.endswith" method and is passed directly + prefix (bool, optional): If true, prepends the path to each result, otherwise + only returns the name of the files found + """ + root = os.path.expanduser(root) + files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)] + if prefix is True: + files = [os.path.join(root, d) for d in files] + return files + + +def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]: + content = response.iter_content(chunk_size) + first_chunk = None + # filter out keep-alive new chunks + while not first_chunk: + first_chunk = next(content) + content = itertools.chain([first_chunk], content) + + try: + match = re.search("Google Drive - (?P<api_response>.+?)", first_chunk.decode()) + api_response = match["api_response"] if match is not None else None + except UnicodeDecodeError: + api_response = None + return api_response, content + + +def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None): + """Download a Google Drive file from and place it in root. + + Args: + file_id (str): id of file to be downloaded + root (str): Directory to place downloaded file in + filename (str, optional): Name to save the file under. If None, use the id of the file. + md5 (str, optional): MD5 checksum of the download. 
If None, do not check + """ + # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url + + root = os.path.expanduser(root) + if not filename: + filename = file_id + fpath = os.path.join(root, filename) + + os.makedirs(root, exist_ok=True) + + if check_integrity(fpath, md5): + print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}") + return + + url = "https://drive.google.com/uc" + params = dict(id=file_id, export="download") + with requests.Session() as session: + response = session.get(url, params=params, stream=True) + + for key, value in response.cookies.items(): + if key.startswith("download_warning"): + token = value + break + else: + api_response, content = _extract_gdrive_api_response(response) + token = "t" if api_response == "Virus scan warning" else None + + if token is not None: + response = session.get(url, params=dict(params, confirm=token), stream=True) + api_response, content = _extract_gdrive_api_response(response) + + if api_response == "Quota exceeded": + raise RuntimeError( + f"The daily quota of the file {filename} is exceeded and it " + f"can't be downloaded. This is a limitation of Google Drive " + f"and can only be overcome by trying again later." + ) + + _save_response_content(content, fpath) + + # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text + if os.stat(fpath).st_size < 10 * 1024: + with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh: + text = fh.read() + # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604 + if re.search(r"]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text): + warnings.warn( + f"We detected some HTML elements in the downloaded file. " + f"This most likely means that the download triggered an unhandled API response by GDrive. " + f"Please report this to torchvision at https://github.com/pytorch/vision/issues including " + f"the response:\n\n{text}" + ) + + if md5 and not check_md5(fpath, md5): + raise RuntimeError( + f"The MD5 checksum of the download file {fpath} does not match the one on record." + f"Please delete the file and try again. " + f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues." + ) + + +def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None: + with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar: + tar.extractall(to_path) + + +_ZIP_COMPRESSION_MAP: Dict[str, int] = { + ".bz2": zipfile.ZIP_BZIP2, + ".xz": zipfile.ZIP_LZMA, +} + + +def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None: + with zipfile.ZipFile( + from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED + ) as zip: + zip.extractall(to_path) + + +_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = { + ".tar": _extract_tar, + ".zip": _extract_zip, +} +_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = { + ".bz2": bz2.open, + ".gz": gzip.open, + ".xz": lzma.open, +} +_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = { + ".tbz": (".tar", ".bz2"), + ".tbz2": (".tar", ".bz2"), + ".tgz": (".tar", ".gz"), +} + + +def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]: + """Detect the archive type and/or compression of a file. 
+ + Args: + file (str): the filename + + Returns: + (tuple): tuple of suffix, archive type, and compression + + Raises: + RuntimeError: if file has no suffix or suffix is not supported + """ + suffixes = pathlib.Path(file).suffixes + if not suffixes: + raise RuntimeError( + f"File '{file}' has no suffixes that could be used to detect the archive type and compression." + ) + suffix = suffixes[-1] + + # check if the suffix is a known alias + if suffix in _FILE_TYPE_ALIASES: + return (suffix, *_FILE_TYPE_ALIASES[suffix]) + + # check if the suffix is an archive type + if suffix in _ARCHIVE_EXTRACTORS: + return suffix, suffix, None + + # check if the suffix is a compression + if suffix in _COMPRESSED_FILE_OPENERS: + # check for suffix hierarchy + if len(suffixes) > 1: + suffix2 = suffixes[-2] + + # check if the suffix2 is an archive type + if suffix2 in _ARCHIVE_EXTRACTORS: + return suffix2 + suffix, suffix2, suffix + + return suffix, None, suffix + + valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS)) + raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.") + + +def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str: + r"""Decompress a file. + + The compression is automatically detected from the file name. + + Args: + from_path (str): Path to the file to be decompressed. + to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used. + remove_finished (bool): If ``True``, remove the file after the extraction. + + Returns: + (str): Path to the decompressed file. + """ + suffix, archive_type, compression = _detect_file_type(from_path) + if not compression: + raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.") + + if to_path is None: + to_path = from_path.replace(suffix, archive_type if archive_type is not None else "") + + # We don't need to check for a missing key here, since this was already done in _detect_file_type() + compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression] + + with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh: + wfh.write(rfh.read()) + + if remove_finished: + os.remove(from_path) + + return to_path + + +def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str: + """Extract an archive. + + The archive type and a possible compression is automatically detected from the file name. If the file is compressed + but not an archive the call is dispatched to :func:`decompress`. + + Args: + from_path (str): Path to the file to be extracted. + to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is + used. + remove_finished (bool): If ``True``, remove the file after the extraction. + + Returns: + (str): Path to the directory the file was extracted to. 
+ """ + if to_path is None: + to_path = os.path.dirname(from_path) + + suffix, archive_type, compression = _detect_file_type(from_path) + if not archive_type: + return _decompress( + from_path, + os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")), + remove_finished=remove_finished, + ) + + # We don't need to check for a missing key here, since this was already done in _detect_file_type() + extractor = _ARCHIVE_EXTRACTORS[archive_type] + + extractor(from_path, to_path, compression) + if remove_finished: + os.remove(from_path) + + return to_path + + +def download_and_extract_archive( + url: str, + download_root: str, + extract_root: Optional[str] = None, + filename: Optional[str] = None, + md5: Optional[str] = None, + remove_finished: bool = False, +) -> None: + download_root = os.path.expanduser(download_root) + if extract_root is None: + extract_root = download_root + if not filename: + filename = os.path.basename(url) + + download_url(url, download_root, filename, md5) + + archive = os.path.join(download_root, filename) + print(f"Extracting {archive} to {extract_root}") + extract_archive(archive, extract_root, remove_finished) + + +def iterable_to_str(iterable: Iterable) -> str: + return "'" + "', '".join([str(item) for item in iterable]) + "'" + + +T = TypeVar("T", str, bytes) + + +def verify_str_arg( + value: T, + arg: Optional[str] = None, + valid_values: Optional[Iterable[T]] = None, + custom_msg: Optional[str] = None, +) -> T: + if not isinstance(value, torch._six.string_classes): + if arg is None: + msg = "Expected type str, but got type {type}." + else: + msg = "Expected type str for argument {arg}, but got type {type}." + msg = msg.format(type=type(value), arg=arg) + raise ValueError(msg) + + if valid_values is None: + return value + + if value not in valid_values: + if custom_msg is not None: + msg = custom_msg + else: + msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}." + msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values)) + raise ValueError(msg) + + return value + + +def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray: + """Read file in .pfm format. Might contain either 1 or 3 channels of data. + + Args: + file_name (str): Path to the file. + slice_channels (int): Number of channels to slice out of the file. + Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc. 
+ """ + + with open(file_name, "rb") as f: + header = f.readline().rstrip() + if header not in [b"PF", b"Pf"]: + raise ValueError("Invalid PFM file") + + dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline()) + if not dim_match: + raise Exception("Malformed PFM header.") + w, h = (int(dim) for dim in dim_match.groups()) + + scale = float(f.readline().rstrip()) + if scale < 0: # little-endian + endian = "<" + scale = -scale + else: + endian = ">" # big-endian + + data = np.fromfile(f, dtype=endian + "f") + + pfm_channels = 3 if header == b"PF" else 1 + + data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1) + data = np.flip(data, axis=1) # flip on h dimension + data = data[:slice_channels, :, :] + return data.astype(np.float32) \ No newline at end of file diff --git a/python/jittor/compatibility/vision/datasets/vision.py b/python/jittor/compatibility/vision/datasets/vision.py new file mode 100644 index 00000000..d71dc2a5 --- /dev/null +++ b/python/jittor/compatibility/vision/datasets/vision.py @@ -0,0 +1,104 @@ +import os +from typing import Any, Callable, List, Optional, Tuple + +import torch +import torch.utils.data as data + +from ..utils import _log_api_usage_once + + +class VisionDataset(data.Dataset): + """ + Base Class For making datasets which are compatible with torchvision. + It is necessary to override the ``__getitem__`` and ``__len__`` method. + Args: + root (string): Root directory of dataset. + transforms (callable, optional): A function/transforms that takes in + an image and a label and returns the transformed versions of both. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + .. note:: + :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive. + """ + + _repr_indent = 4 + + def __init__( + self, + root: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + self.root = root + + has_transforms = transforms is not None + has_separate_transform = transform is not None or target_transform is not None + if has_transforms and has_separate_transform: + raise ValueError("Only transforms or transform/target_transform can be passed as argument") + + # for backwards-compatibility + self.transform = transform + self.target_transform = target_transform + + if has_separate_transform: + transforms = StandardTransform(transform, target_transform) + self.transforms = transforms + + def __getitem__(self, index: int) -> Any: + """ + Args: + index (int): Index + Returns: + (Any): Sample and meta data, optionally transformed by the respective transforms. 
+ """ + raise NotImplementedError + + def __len__(self) -> int: + raise NotImplementedError + + def __repr__(self) -> str: + head = "Dataset " + self.__class__.__name__ + body = [f"Number of datapoints: {self.__len__()}"] + if self.root is not None: + body.append(f"Root location: {self.root}") + body += self.extra_repr().splitlines() + if hasattr(self, "transforms") and self.transforms is not None: + body += [repr(self.transforms)] + lines = [head] + [" " * self._repr_indent + line for line in body] + return "\n".join(lines) + + def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: + lines = transform.__repr__().splitlines() + return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]] + + def extra_repr(self) -> str: + return "" + + +class StandardTransform: + def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None: + self.transform = transform + self.target_transform = target_transform + + def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]: + if self.transform is not None: + input = self.transform(input) + if self.target_transform is not None: + target = self.target_transform(target) + return input, target + + def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: + lines = transform.__repr__().splitlines() + return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]] + + def __repr__(self) -> str: + body = [self.__class__.__name__] + if self.transform is not None: + body += self._format_transform_repr(self.transform, "Transform: ") + if self.target_transform is not None: + body += self._format_transform_repr(self.target_transform, "Target transform: ") + + return "\n".join(body) \ No newline at end of file diff --git a/python/jittor/compatibility/vision/transforms.py b/python/jittor/compatibility/vision/transforms.py new file mode 100644 index 00000000..416057c7 --- /dev/null +++ b/python/jittor/compatibility/vision/transforms.py @@ -0,0 +1 @@ +from jittor.transform import * \ No newline at end of file diff --git a/python/jittor/compatibility/vision/utils.py b/python/jittor/compatibility/vision/utils.py new file mode 100644 index 00000000..4be36c64 --- /dev/null +++ b/python/jittor/compatibility/vision/utils.py @@ -0,0 +1,582 @@ +import collections +import math +import pathlib +import warnings +from itertools import repeat +from types import FunctionType +from typing import Any, BinaryIO, List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL import Image, ImageColor, ImageDraw, ImageFont + +__all__ = [ + "make_grid", + "save_image", + "draw_bounding_boxes", + "draw_segmentation_masks", + "draw_keypoints", + "flow_to_image", +] + + +@torch.no_grad() +def make_grid( + tensor: Union[torch.Tensor, List[torch.Tensor]], + nrow: int = 8, + padding: int = 2, + normalize: bool = False, + value_range: Optional[Tuple[int, int]] = None, + scale_each: bool = False, + pad_value: float = 0.0, + **kwargs, +) -> torch.Tensor: + """ + Make a grid of images. + + Args: + tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W) + or a list of images all of the same size. + nrow (int, optional): Number of images displayed in each row of the grid. + The final grid size is ``(B / nrow, nrow)``. Default: ``8``. + padding (int, optional): amount of padding. Default: ``2``. 
+ normalize (bool, optional): If True, shift the image to the range (0, 1), + by the min and max values specified by ``value_range``. Default: ``False``. + value_range (tuple, optional): tuple (min, max) where min and max are numbers, + then these numbers are used to normalize the image. By default, min and max + are computed from the tensor. + range (tuple. optional): + .. warning:: + This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``value_range`` + instead. + scale_each (bool, optional): If ``True``, scale each image in the batch of + images separately rather than the (min, max) over all images. Default: ``False``. + pad_value (float, optional): Value for the padded pixels. Default: ``0``. + + Returns: + grid (Tensor): the tensor containing grid of images. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(make_grid) + if not torch.is_tensor(tensor): + if isinstance(tensor, list): + for t in tensor: + if not torch.is_tensor(t): + raise TypeError(f"tensor or list of tensors expected, got a list containing {type(t)}") + else: + raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}") + + if "range" in kwargs.keys(): + warnings.warn( + "The parameter 'range' is deprecated since 0.12 and will be removed in 0.14. " + "Please use 'value_range' instead." + ) + value_range = kwargs["range"] + + # if list of tensors, convert to a 4D mini-batch Tensor + if isinstance(tensor, list): + tensor = torch.stack(tensor, dim=0) + + if tensor.dim() == 2: # single image H x W + tensor = tensor.unsqueeze(0) + if tensor.dim() == 3: # single image + if tensor.size(0) == 1: # if single-channel, convert to 3-channel + tensor = torch.cat((tensor, tensor, tensor), 0) + tensor = tensor.unsqueeze(0) + + if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images + tensor = torch.cat((tensor, tensor, tensor), 1) + + if normalize is True: + tensor = tensor.clone() # avoid modifying tensor in-place + if value_range is not None and not isinstance(value_range, tuple): + raise TypeError("value_range has to be a tuple (min, max) if specified. 
min and max are numbers") + + def norm_ip(img, low, high): + img.clamp_(min=low, max=high) + img.sub_(low).div_(max(high - low, 1e-5)) + + def norm_range(t, value_range): + if value_range is not None: + norm_ip(t, value_range[0], value_range[1]) + else: + norm_ip(t, float(t.min()), float(t.max())) + + if scale_each is True: + for t in tensor: # loop over mini-batch dimension + norm_range(t, value_range) + else: + norm_range(tensor, value_range) + + if not isinstance(tensor, torch.Tensor): + raise TypeError("tensor should be of type torch.Tensor") + if tensor.size(0) == 1: + return tensor.squeeze(0) + + # make the mini-batch of images into a grid + nmaps = tensor.size(0) + xmaps = min(nrow, nmaps) + ymaps = int(math.ceil(float(nmaps) / xmaps)) + height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding) + num_channels = tensor.size(1) + grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value) + k = 0 + for y in range(ymaps): + for x in range(xmaps): + if k >= nmaps: + break + # Tensor.copy_() is a valid method but seems to be missing from the stubs + # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_ + grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined] + 2, x * width + padding, width - padding + ).copy_(tensor[k]) + k = k + 1 + return grid + + +@torch.no_grad() +def save_image( + tensor: Union[torch.Tensor, List[torch.Tensor]], + fp: Union[str, pathlib.Path, BinaryIO], + format: Optional[str] = None, + **kwargs, +) -> None: + """ + Save a given Tensor into an image file. + + Args: + tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, + saves the tensor as a grid of images by calling ``make_grid``. + fp (string or file object): A filename or a file object + format(Optional): If omitted, the format to use is determined from the filename extension. + If a file object was used instead of a filename, this parameter should always be used. + **kwargs: Other arguments are documented in ``make_grid``. + """ + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(save_image) + grid = make_grid(tensor, **kwargs) + # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer + ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + im = Image.fromarray(ndarr) + im.save(fp, format=format) + + +@torch.no_grad() +def draw_bounding_boxes( + image: torch.Tensor, + boxes: torch.Tensor, + labels: Optional[List[str]] = None, + colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, + fill: Optional[bool] = False, + width: int = 1, + font: Optional[str] = None, + font_size: Optional[int] = None, +) -> torch.Tensor: + + """ + Draws bounding boxes on given image. + The values of the input image should be uint8 between 0 and 255. + If fill is True, Resulting Tensor should be saved as PNG image. + + Args: + image (Tensor): Tensor of shape (C x H x W) and dtype uint8. + boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that + the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and + `0 <= ymin < ymax < H`. + labels (List[str]): List containing the labels of bounding boxes. + colors (color or list of colors, optional): List containing the colors + of the boxes or single color for all boxes. The color can be represented as + PIL strings e.g. 
"red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + By default, random colors are generated for boxes. + fill (bool): If `True` fills the bounding box with specified color. + width (int): Width of bounding box. + font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may + also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, + `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. + font_size (int): The requested font size in points. + + Returns: + img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. + """ + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(draw_bounding_boxes) + if not isinstance(image, torch.Tensor): + raise TypeError(f"Tensor expected, got {type(image)}") + elif image.dtype != torch.uint8: + raise ValueError(f"Tensor uint8 expected, got {image.dtype}") + elif image.dim() != 3: + raise ValueError("Pass individual images, not batches") + elif image.size(0) not in {1, 3}: + raise ValueError("Only grayscale and RGB images are supported") + elif (boxes[:, 0] > boxes[:, 2]).any() or (boxes[:, 1] > boxes[:, 3]).any(): + raise ValueError( + "Boxes need to be in (xmin, ymin, xmax, ymax) format. Use torchvision.ops.box_convert to convert them" + ) + + num_boxes = boxes.shape[0] + + if num_boxes == 0: + warnings.warn("boxes doesn't contain any box. No box was drawn") + return image + + if labels is None: + labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef] + elif len(labels) != num_boxes: + raise ValueError( + f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box." + ) + + if colors is None: + colors = _generate_color_palette(num_boxes) + elif isinstance(colors, list): + if len(colors) < num_boxes: + raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ") + else: # colors specifies a single color for all boxes + colors = [colors] * num_boxes + + colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors] + + if font is None: + if font_size is not None: + warnings.warn("Argument 'font_size' will be ignored since 'font' is not set.") + txt_font = ImageFont.load_default() + else: + txt_font = ImageFont.truetype(font=font, size=font_size or 10) + + # Handle Grayscale images + if image.size(0) == 1: + image = torch.tile(image, (3, 1, 1)) + + ndarr = image.permute(1, 2, 0).cpu().numpy() + img_to_draw = Image.fromarray(ndarr) + img_boxes = boxes.to(torch.int64).tolist() + + if fill: + draw = ImageDraw.Draw(img_to_draw, "RGBA") + else: + draw = ImageDraw.Draw(img_to_draw) + + for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type] + if fill: + fill_color = color + (100,) + draw.rectangle(bbox, width=width, outline=color, fill=fill_color) + else: + draw.rectangle(bbox, width=width, outline=color) + + if label is not None: + margin = width + 1 + draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font) + + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) + + +@torch.no_grad() +def draw_segmentation_masks( + image: torch.Tensor, + masks: torch.Tensor, + alpha: float = 0.8, + colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, +) -> torch.Tensor: + + """ + Draws segmentation masks on given RGB image. 
+ The values of the input image should be uint8 between 0 and 255. + + Args: + image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. + alpha (float): Float number between 0 and 1 denoting the transparency of the masks. + 0 means full transparency, 1 means no transparency. + colors (color or list of colors, optional): List containing the colors + of the masks or single color for all masks. The color can be represented as + PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + By default, random colors are generated for each mask. + + Returns: + img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top. + """ + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(draw_segmentation_masks) + if not isinstance(image, torch.Tensor): + raise TypeError(f"The image must be a tensor, got {type(image)}") + elif image.dtype != torch.uint8: + raise ValueError(f"The image dtype must be uint8, got {image.dtype}") + elif image.dim() != 3: + raise ValueError("Pass individual images, not batches") + elif image.size()[0] != 3: + raise ValueError("Pass an RGB image. Other Image formats are not supported") + if masks.ndim == 2: + masks = masks[None, :, :] + if masks.ndim != 3: + raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)") + if masks.dtype != torch.bool: + raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}") + if masks.shape[-2:] != image.shape[-2:]: + raise ValueError("The image and the masks must have the same height and width") + + num_masks = masks.size()[0] + if colors is not None and num_masks > len(colors): + raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") + + if num_masks == 0: + warnings.warn("masks doesn't contain any mask. No mask was drawn") + return image + + if colors is None: + colors = _generate_color_palette(num_masks) + + if not isinstance(colors, list): + colors = [colors] + if not isinstance(colors[0], (tuple, str)): + raise ValueError("colors must be a tuple or a string, or a list thereof") + if isinstance(colors[0], tuple) and len(colors[0]) != 3: + raise ValueError("It seems that you passed a tuple of colors instead of a list of colors") + + out_dtype = torch.uint8 + + colors_ = [] + for color in colors: + if isinstance(color, str): + color = ImageColor.getrgb(color) + colors_.append(torch.tensor(color, dtype=out_dtype)) + + img_to_draw = image.detach().clone() + # TODO: There might be a way to vectorize this + for mask, color in zip(masks, colors_): + img_to_draw[:, mask] = color[:, None] + + out = image * (1 - alpha) + img_to_draw * alpha + return out.to(out_dtype) + + +@torch.no_grad() +def draw_keypoints( + image: torch.Tensor, + keypoints: torch.Tensor, + connectivity: Optional[List[Tuple[int, int]]] = None, + colors: Optional[Union[str, Tuple[int, int, int]]] = None, + radius: int = 2, + width: int = 3, +) -> torch.Tensor: + + """ + Draws Keypoints on given RGB image. + The values of the input image should be uint8 between 0 and 255. + + Args: + image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, + in the format [x, y]. + connectivity (List[Tuple[int, int]]]): A List of tuple where, + each tuple contains pair of keypoints to be connected. + colors (str, Tuple): The color can be represented as + PIL strings e.g. 
"red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + radius (int): Integer denoting radius of keypoint. + width (int): Integer denoting width of line connecting keypoints. + + Returns: + img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. + """ + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(draw_keypoints) + if not isinstance(image, torch.Tensor): + raise TypeError(f"The image must be a tensor, got {type(image)}") + elif image.dtype != torch.uint8: + raise ValueError(f"The image dtype must be uint8, got {image.dtype}") + elif image.dim() != 3: + raise ValueError("Pass individual images, not batches") + elif image.size()[0] != 3: + raise ValueError("Pass an RGB image. Other Image formats are not supported") + + if keypoints.ndim != 3: + raise ValueError("keypoints must be of shape (num_instances, K, 2)") + + ndarr = image.permute(1, 2, 0).cpu().numpy() + img_to_draw = Image.fromarray(ndarr) + draw = ImageDraw.Draw(img_to_draw) + img_kpts = keypoints.to(torch.int64).tolist() + + for kpt_id, kpt_inst in enumerate(img_kpts): + for inst_id, kpt in enumerate(kpt_inst): + x1 = kpt[0] - radius + x2 = kpt[0] + radius + y1 = kpt[1] - radius + y2 = kpt[1] + radius + draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0) + + if connectivity: + for connection in connectivity: + start_pt_x = kpt_inst[connection[0]][0] + start_pt_y = kpt_inst[connection[0]][1] + + end_pt_x = kpt_inst[connection[1]][0] + end_pt_y = kpt_inst[connection[1]][1] + + draw.line( + ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)), + width=width, + ) + + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) + + +# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization +@torch.no_grad() +def flow_to_image(flow: torch.Tensor) -> torch.Tensor: + + """ + Converts a flow to an RGB image. + + Args: + flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float. + + Returns: + img (Tensor): Image Tensor of dtype uint8 where each color corresponds + to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input. + """ + + if flow.dtype != torch.float: + raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.") + + orig_shape = flow.shape + if flow.ndim == 3: + flow = flow[None] # Add batch dim + + if flow.ndim != 4 or flow.shape[1] != 2: + raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.") + + max_norm = torch.sum(flow**2, dim=1).sqrt().max() + epsilon = torch.finfo((flow).dtype).eps + normalized_flow = flow / (max_norm + epsilon) + img = _normalized_flow_to_image(normalized_flow) + + if len(orig_shape) == 3: + img = img[0] # Remove batch dim + return img + + +@torch.no_grad() +def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor: + + """ + Converts a batch of normalized flow to an RGB image. + + Args: + normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W) + Returns: + img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8. 
+ """ + + N, _, H, W = normalized_flow.shape + device = normalized_flow.device + flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device) + colorwheel = _make_colorwheel().to(device) # shape [55x3] + num_cols = colorwheel.shape[0] + norm = torch.sum(normalized_flow**2, dim=1).sqrt() + a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi + fk = (a + 1) / 2 * (num_cols - 1) + k0 = torch.floor(fk).to(torch.long) + k1 = k0 + 1 + k1[k1 == num_cols] = 0 + f = fk - k0 + + for c in range(colorwheel.shape[1]): + tmp = colorwheel[:, c] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1 - f) * col0 + f * col1 + col = 1 - norm * (1 - col) + flow_image[:, c, :, :] = torch.floor(255 * col) + return flow_image + + +def _make_colorwheel() -> torch.Tensor: + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf. + + Returns: + colorwheel (Tensor[55, 3]): Colorwheel Tensor. + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = torch.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY) + col = col + RY + # YG + colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG) + colorwheel[col : col + YG, 1] = 255 + col = col + YG + # GC + colorwheel[col : col + GC, 1] = 255 + colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC) + col = col + GC + # CB + colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB) + colorwheel[col : col + CB, 2] = 255 + col = col + CB + # BM + colorwheel[col : col + BM, 2] = 255 + colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM) + col = col + BM + # MR + colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR) + colorwheel[col : col + MR, 0] = 255 + return colorwheel + + +def _generate_color_palette(num_objects: int): + palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1]) + return [tuple((i * palette) % 255) for i in range(num_objects)] + + +def _log_api_usage_once(obj: Any) -> None: + + """ + Logs API usage(module and name) within an organization. + In a large ecosystem, it's often useful to track the PyTorch and + TorchVision APIs usage. This API provides the similar functionality to the + logging module in the Python stdlib. It can be used for debugging purpose + to log which methods are used and by default it is inactive, unless the user + manually subscribes a logger via the `SetAPIUsageLogger method `_. + Please note it is triggered only once for the same API call within a process. + It does not collect any data from open-source users since it is no-op by default. + For more information, please refer to + * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging; + * Logging policy: https://github.com/pytorch/vision/issues/5052; + + Args: + obj (class instance or method): an object to extract info from. + """ + pass + + +def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]: + """ + Make n-tuple from input x. If x is an iterable, then we just convert it to tuple. + Otherwise we will make a tuple of length n, all with value of x. 
+ reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8 + + Args: + x (Any): input value + n (int): length of the resulting tuple + """ + if isinstance(x, collections.abc.Iterable): + return tuple(x) + return tuple(repeat(x, n)) \ No newline at end of file diff --git a/python/jittor/compile_extern.py b/python/jittor/compile_extern.py index 07b77c22..87caa2b3 100644 --- a/python/jittor/compile_extern.py +++ b/python/jittor/compile_extern.py @@ -610,6 +610,26 @@ def setup_nccl(): nccl_ops = nccl.ops LOG.vv("Get nccl_ops: "+str(dir(nccl_ops))) +def setup_hccl(): + global hccl_ops + + hccl_src_dir = os.path.join(jittor_path, "extern", "acl", "hccl") + hccl_src_files = [] + for r, _, f in os.walk(hccl_src_dir): + for fname in f: + hccl_src_files.append(os.path.join(r, fname)) + + hccl_include_path = os.path.join(os.environ.get("ASCEND_TOOLKIT_HOME"), "aarch64-linux/include/hccl") + hccl_lib_name = os.path.join(os.environ.get("ASCEND_TOOLKIT_HOME"), "aarch64-linux/lib64/libhccl.so") + ctypes.CDLL(hccl_lib_name, dlopen_flags) + + hccl = compile_custom_ops(hccl_src_files, + extra_flags=f" -I\"{hccl_include_path}\" {mpi_compile_flags} ", + return_module=True, dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW, + gen_name_="jittor_hccl_core") + hccl_ops = hccl.ops + LOG.vv("Get hccl_ops: "+str(dir(hccl_ops))) + def manual_link(flags): lib_dirs = [] libs = [] @@ -707,8 +727,15 @@ cudnn = cublas = curand = cufft = cusparse = None setup_mpi() rank = mpi.world_rank() if in_mpi else 0 world_size = mpi.world_size() if in_mpi else 1 -setup_nccl() +# if has_acl: +# setup_hccl() +# elif has_cuda: +# setup_nccl() +# setup_cutt() +# setup_cutlass() + +setup_nccl() setup_cutt() setup_cutlass() @@ -723,3 +750,4 @@ setup_cuda_extern() for mod in jit_utils.backends: if mod.install_extern(): break + diff --git a/python/jittor/compiler.py b/python/jittor/compiler.py index 81d43956..d6675971 100644 --- a/python/jittor/compiler.py +++ b/python/jittor/compiler.py @@ -1186,9 +1186,23 @@ make_cache_dir(os.path.join(cache_path, "tmp")) ck_path = os.path.join(cache_path, "checkpoints") make_cache_dir(ck_path) + +ascend_toolkit_home = os.getenv('ASCEND_TOOLKIT_HOME') + # build cache_compile cc_flags += f" -I\"{os.path.join(jittor_path, 'src')}\" " +cc_flags += f" -I\"{os.path.join(jittor_path, 'extern')}\" " +cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include')}\" " +cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/acl')}\" " +cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/aclnn')}\" " +cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/aclnnop')}\" " +cc_flags += f" -L\"{os.path.join(ascend_toolkit_home, 'lib64')}\" " +cc_flags += " -llibascendcl " +cc_flags += " -llibnnopbase " +cc_flags += " -llibopapi " + cc_flags += py_include + check_cache_compile() LOG.v(f"Get cache_compile: {jit_utils.cc}") @@ -1306,6 +1320,10 @@ for v in at_last: registers = [ name for name in files4 if "register" in name ] for name in registers: files4.remove(name) files = registers + files2 + files4 + + +#print(extra_core_files) +#extra_core_files.append("/home/ma-user/work/jittor/python/jittor/extern/acl/aclnn/aclnn.cc") files += extra_core_files for file in jit_utils_core_files: files.remove(file) @@ -1355,6 +1373,10 @@ else: files = [f for f in files if "__data__" not in f or "src" in f.split("__data__")[1]] + +#print(jittor_path) +#print(cc_flags) +#print(files) cc_flags += f" -l\"jit_utils_core{lib_suffix}\" " compile(cc_path, cc_flags+opt_flags, files, 
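The hccl setup and the compiler.py changes above both derive their include and library search paths from ASCEND_TOOLKIT_HOME. A small sketch of how such flags are assembled; the fallback path is only an assumption so the snippet runs without a toolkit installed, the real code requires the environment variable to be set:

import os

toolkit = os.getenv("ASCEND_TOOLKIT_HOME",
                    "/usr/local/Ascend/ascend-toolkit/latest")   # assumed default install location
include_dirs = ["include", "include/acl", "include/aclnn", "include/aclnnop"]
cc_flags = "".join(f' -I"{os.path.join(toolkit, d)}" ' for d in include_dirs)
cc_flags += f' -L"{os.path.join(toolkit, "lib64")}" '
print(cc_flags)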
'jittor_core'+extension_suffix) cc_flags += f" -l\"jittor_core{lib_suffix}\" " @@ -1371,7 +1393,7 @@ if has_cuda and is_cuda: nvcc_flags = " " + os.environ.get("nvcc_flags", "") + " " nvcc_flags += convert_nvcc_flags(cc_flags) nvcc_version = list(jit_utils.get_int_version(nvcc_path)) - max_arch = 90 + max_arch = 89 if nvcc_version < [11,]: max_arch = 75 elif nvcc_version < [11,1]: diff --git a/python/jittor/extern/acl/acl_compiler.py b/python/jittor/extern/acl/acl_compiler.py index 22e366eb..8185ee89 100644 --- a/python/jittor/extern/acl/acl_compiler.py +++ b/python/jittor/extern/acl/acl_compiler.py @@ -11,6 +11,24 @@ import ctypes import glob import jittor.compiler as compiler import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def _ntuple(n): + + def parse(x): + if isinstance(x, Iterable): + return x + return tuple([x] * n) + + return parse + + +_pair = _ntuple(2) has_acl = 0 cc_flags = "" @@ -49,17 +67,34 @@ def install(): recursive=True)) cc_files2 = [] for name in cc_files: - if "acl_op_exec" in name: + # Skip files in hccl directory + if "hccl" in name: + continue + # if "acl_op_exec" in name or "_op_acl.cc" in name: + if "acl_op_exec" in name or "_op_acl.cc" in name or "utils.cc" in name: compiler.extra_core_files.append(name) else: cc_files2.append(name) cc_files = cc_files2 ascend_toolkit_home = os.getenv('ASCEND_TOOLKIT_HOME') - cc_flags += f" -DHAS_CUDA -DIS_ACL \ - -I{ascend_toolkit_home}/include/ \ - -L{ascend_toolkit_home}/lib64/ \ - -I{acl_compiler_home} -lascendcl -lacl_op_compiler " + #print(ascend_toolkit_home) + #print(acl_compiler_home) + cc_flags += f" -MD -DHAS_CUDA -DIS_ACL \ + -I{ascend_toolkit_home}/include/ \ + -I{ascend_toolkit_home}/include/acl/ \ + -I{ascend_toolkit_home}/include/aclnn/ \ + -I{ascend_toolkit_home}/include/aclnnop/ \ + -I{acl_compiler_home} -lascendcl -lacl_op_compiler \ + -I{acl_compiler_home}/aclnn \ + -I{acl_compiler_home}/aclops \ + -L{ascend_toolkit_home}/lib64/" + + cc_flags += " -llibascendcl " + cc_flags += " -llibnnopbase " + cc_flags += " -llibopapi " + + #pdb.set_trace() ctypes.CDLL("libascendcl.so", dlopen_flags) f''' -ltikc_runtime @@ -118,1237 +153,544 @@ def post_process(): jt.flags.amp_reg |= 32 + 4 # 32 keep float16, 4 keep reduce type mod.init_acl_ops() - -def acl_cmd(name: str, inputs: list, output_dtypes: list, output_shapes: list, - attr: dict): - nchw_op = ['MaxPoolWithArgmaxV1', 'MaxPoolGradWithArgmaxV1', 'AvgPoolV2'] - attr_op = [ - 'MaxPoolWithArgmaxV1', 'MaxPoolGradWithArgmaxV1', 'AvgPoolV2', - 'AdaptiveAvgPool2d', 'AdaptiveAvgPool2dGrad', 'ReverseV2' - ] - - input_code = '' - for i in range(len(inputs)): - if name in nchw_op: - input_code += f"op.add(in{i}, true, ACL_FORMAT_NCHW);\n" - else: - input_code += f"op.add(in{i}, true);\n" - - output_code = '' - for i in range(len(output_dtypes)): - if name in nchw_op: - output_code += f"op.add(out{i}, false, ACL_FORMAT_NCHW);\n" - else: - output_code += f"op.add(out{i}, false);\n" - - # add attr to op - attr_code = '' - if name in attr_op: - for k, v in attr.items(): - if isinstance(v, bool): - if v == True: - attr_code += f"op.set_attr(\"{k}\", 1, 1);\n" - else: - attr_code += f"op.set_attr(\"{k}\", 1, 0);\n" - elif isinstance(v, str): - attr_code += f"op.set_attr(\"{k}\", \"{v}\");\n" - elif k == 'divisor_override_value': - attr_code += f"op.set_attr(\"{k}\", int64_t({v}), 0);\n" - else: - v = str(v).replace('[', '{').replace(']', '}') - attr_code += f"op.set_attr(\"{k}\", vector{v});\n" - 
else: - for k, v in attr.items(): - if isinstance(v, bool): - if v == True: - attr_code += f"op.set_attr(\"{k}\", 1, 1);\n" - else: - attr_code += f"op.set_attr(\"{k}\", 1, 0);\n" - elif isinstance(v, str): - attr_code += f"op.set_attr(\"{k}\", \"{v}\");\n" - else: - attr_code += f"op.set_attr(\"{k}\", int({v}));\n" - - #print("input_code",input_code) - #print("attr_code",attr_code) - import jittor as jt - return jt.code(output_shapes, - output_dtypes, - inputs, - cuda_header=""" - #include - #include - #include - #include - - namespace jittor { - - void printDeviceData(const vector& output_desc, const vector& output_data, const string& name = "", bool input=true) { - LOGir << "name: " << name; - if(input) - LOGir << "is input"; - else - LOGir << "is ouput"; - for (size_t i = 0; i < output_desc.size(); ++i) { - void* base_addr = aclGetDataBufferAddr(output_data[i]); - LOGir << "addr of data[" << i << "] :" << base_addr; - size_t num_dims = aclGetTensorDescNumDims(output_desc[i]); - size_t total_size = 1; - std::vector dims(num_dims); - - std::cout << "shape of data: "; - for (size_t j = 0; j < num_dims; ++j) { - aclGetTensorDescDimV2(output_desc[i], j, &dims[j]); - total_size *= dims[j]; - std::cout << dims[j] << ", "; - } - int evey_batch_size = total_size/dims[0]; - std::cout << std::endl; - - // for(int i= 0; i < dims[0]; i++) { - // evey_batch_size = 16; - // std::vector host_buffer(evey_batch_size); - // void* offset_addr = static_cast(base_addr) + i * evey_batch_size * sizeof(float); - // aclrtMemcpy(host_buffer.data(), evey_batch_size * sizeof(float), offset_addr, evey_batch_size * sizeof(float), ACL_MEMCPY_DEVICE_TO_HOST); - // std::cout << "batch[" << i << "]:"; - // for (size_t k = 0; k < evey_batch_size; ++k) { - // std::cout << host_buffer[k] << ", "; - // } - // std::cout << std::endl; - // } - } - } - - struct AclOpRunner { - string name; - vector input_desc; - vector output_desc; - vector input_data; - vector output_data; - aclopAttr *attr; - vector> input_host; - vector> input_host_32; - - AclOpRunner(const string& name) : name(name) { - attr = aclopCreateAttr(); - } - - ~AclOpRunner() { - for (auto i : input_desc) aclDestroyTensorDesc(i); - for (auto i : output_desc) aclDestroyTensorDesc(i); - for (auto i : input_data) aclDestroyDataBuffer(i); - for (auto i : output_data) aclDestroyDataBuffer(i); - aclopDestroyAttr(attr); - } - - aclDataType get_dtype(NanoString s) { - if (s == ns_float32) return ACL_FLOAT; - if (s == ns_float16) return ACL_FLOAT16; - if (s == ns_int64) return ACL_INT64; - if (s == ns_int32) return ACL_INT32; - if (s == ns_int8) return ACL_INT8; - if (s == ns_int16) return ACL_INT16; - if (s == ns_uint8) return ACL_UINT8; - if (s == ns_uint16) return ACL_UINT16; - if (s == ns_uint32) return ACL_UINT32; - if (s == ns_bool) return ACL_BOOL; - LOGf << "Not supported dtype: " << s; - return ACL_FLOAT; - } - - void add(Var* v, bool is_input, int format=ACL_FORMAT_ND) { - int64_t shape[v->shape.size()]; - for (int i=0; ishape.size(); i++) shape[i] = v->shape[i]; - - auto desc = aclCreateTensorDesc(get_dtype(v->dtype()), v->shape.size(), &shape[0], (aclFormat)format); - aclSetTensorFormat(desc, (aclFormat)format); - aclSetTensorShape(desc, v->shape.size(), &shape[0]); - LOGv << "aclCreateTensorDesc" << (int)get_dtype(v->dtype()) << v->shape.size() << &shape[0] << format; - auto data = aclCreateDataBuffer(v->mem_ptr, v->size); - LOGv << "aclCreateDataBuffer" << v->mem_ptr << v->size; - ASSERT(desc && data); - if (is_input) { - input_desc.push_back(desc); - 
input_data.push_back(data); - } else { - output_desc.push_back(desc); - output_data.push_back(data); - } - } - - void add_input_host(vector v, int dtype=ACL_UINT64) { - int64_t shape[1]; - shape[0] = v.size(); - auto desc = aclCreateTensorDesc((aclDataType)dtype, 1, &shape[0], ACL_FORMAT_ND); - aclSetTensorPlaceMent(desc, ACL_MEMTYPE_HOST_COMPILE_INDEPENDENT); - LOGv << "aclCreateTensorDesc" << dtype << 1 << &shape[0] << ACL_FORMAT_ND; - auto data = aclCreateDataBuffer(&v[0], v.size()*sizeof(uint64)); - ASSERT(desc && data); - LOGv << "aclCreateDataBuffer" << &v[0] << v.size()*sizeof(uint64); - input_desc.push_back(desc); - input_data.push_back(data); - input_host.emplace_back(move(v)); - LOGv << "move" << input_host.back().data(); - } - - void add_input_host_scalar(vector v, int dtype=ACL_UINT32) { - int64_t shape[1]; - shape[0] = v.size(); - auto x = (int*)&v[0]; - x[0] = (int32)v[0]; - auto desc = aclCreateTensorDesc((aclDataType)dtype, 0, &shape[0], ACL_FORMAT_ND); - aclSetTensorPlaceMent(desc, ACL_MEMTYPE_HOST_COMPILE_INDEPENDENT); - LOGv << "aclCreateTensorDesc" << dtype << 1 << &shape[0] << ACL_FORMAT_ND; - auto data = aclCreateDataBuffer(&v[0], v.size()*sizeof(uint32)); - ASSERT(desc && data); - LOGv << "aclCreateDataBuffer" << &v[0] << v.size()*sizeof(uint32); - input_desc.push_back(desc); - input_data.push_back(data); - input_host.emplace_back(move(v)); - } - - void add_input_host_nv(NanoVector nv, int dtype=ACL_UINT64) { - vector v(nv.size()); - for (int i=0; i v(nv.size()); - for (int i=0; i value) { - // LOGir << "string vector" << "set_attr" << key << value; - CHECK(aclopSetAttrListInt(attr, key.c_str(), value.size(), &value[0])==0); - } - void set_attr(const string& key, string value) { - // LOGir << "string string" << "set_attr" << key << value; - CHECK(aclopSetAttrString(attr, key.c_str(), value.c_str())==0); - } - void set_attr(const char* key, const char* value) { - // LOGir << "char" << "set_attr" << key << value; - CHECK(aclopSetAttrString(attr, key, value)==0); - } - - void run() { - // printDeviceData(input_desc, input_data, name); - - LOGv << "run" << name << input_desc.size() << output_desc.size(); - if (!PyGILState_Check()) { - ASSERT(0==aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream)); - } else { - int ret; - Py_BEGIN_ALLOW_THREADS - ret = aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream); - Py_END_ALLOW_THREADS - if (ret != 0) - LOGf << "aclopCompileAndExecuteV2" << name << "failed return" << ret; - } - ASSERT(0==aclrtSynchronizeDevice()); - - // printDeviceData(output_desc, output_data, name, false); - } - }; - - } - """, - cuda_src=f""" - // aclop - AclOpRunner op("{name}"); - {input_code} - {output_code} - {attr_code} - op.run();""") - - def change_function(): import jittor as jt from jittor import Function + from .aclops.flashattention_op import FlashAttentionACL + from .aclops.conv_op import ConvACL + from .aclops.pool_op import PoolACL + from .aclops.nantonum_op import NanToNumACL + from .aclops.stack_op import StackACL + from .aclops.rope_op import RopeACL + from .aclops.softmax_op import SoftmaxACL + from .aclops.sigmoid_op import SigmoidACL + from .aclops.silu_op import SiLUACL + from .aclops.dropout_op import DropoutACL + from .aclops.relu_op import 
LeakyReLUACL + from .aclops.flip_op import FlipACL + from .aclops.concat_op import ConcatACL + from .aclops.gather_scatter_op import GatherACL + from .aclops.cumsum_op import CumsumACL + from .aclops.index_op import IndexACL + from .aclops.gather_scatter_op import ScatterACL + from .aclops.where_op import WhereACL + from .aclops.where_op import NonzeroACL + from .aclops.floor_op import FloorIntACL + from .aclops.getitem_op import GetItemACL + from .aclops.setitem_op import SetItemACL + from .aclops.bmm_op import BmmACL + from .aclops.matmul_op import MatmulACL + from .aclops.transpose_op import TransPoseACL - class IndexACL(Function): + from .aclops.triu_op import TriuACL - def __init__(self): - super(IndexACL, self).__init__() + def triu_acl(x, diagonal=0): + return TriuACL()(x, diagonal) - def execute(self, inshape: list, dim, dtype="int32"): - # zeros a tensor, shape is inshape, dtype is dtype - dim_input = dim - if dim == None: - dim = [i for i in range(len(inshape))] - elif type(dim) == int: - dim = [dim] - results = [] - for d in dim: - max_len = inshape[d] - tmp = jt.zeros(max_len, dtype=dtype) - result = acl_cmd( - "Range", [jt.Var(0), jt.Var(max_len), - jt.Var(1)], - output_dtypes=[tmp.dtype], - output_shapes=[tmp.shape], - attr={})[0] - broadcast_dim = [] - for i in range(len(inshape)): - if i != d: - broadcast_dim.append(i) - result = jt.broadcast(result, - shape=inshape, - dims=broadcast_dim) - results.append(result) - if len(results) != 1 or dim_input == None: - return tuple(results) - else: - return results[0] + from .aclops.conv_op import ConvACL - def grad(self, grad_output): - return grad_output + def conv_acl(x, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1): + return ConvACL()(x, weight, bias, stride, padding, dilation, groups) - class PoolACL(Function): - - def get_paddings(self): - pad_top = self.padding[0] - pad_left = self.padding[1] - H = self.input.shape[-2] - W = self.input.shape[-1] - - totalH = H + 2 * self.padding[0] - self.kernel_size[0] - totalW = W + 2 * self.padding[1] - self.kernel_size[1] - - kH = (totalH + self.stride[0] - - 1) // self.stride[0] + 1 if self.attr[ - 'ceil_mode'] else totalH // self.stride[0] + 1 - kW = (totalW + self.stride[1] - - 1) // self.stride[1] + 1 if self.attr[ - 'ceil_mode'] else totalW // self.stride[1] + 1 - - if self.attr['ceil_mode']: - if (kH - 1) * self.stride[0] >= H + self.padding[0]: - kH -= 1 - need_pad_h = (kH - - 1) * self.stride[0] + self.kernel_size[0] - H - pad_top = need_pad_h - self.padding[0] - if (kW - 1) * self.stride[1] >= W + self.padding[1]: - kW -= 1 - need_pad_w = (kW - - 1) * self.stride[1] + self.kernel_size[1] - W - pad_left = need_pad_w - self.padding[1] - - pads = [self.padding[0], pad_top, self.padding[1], pad_left] - return pads + class Conv2D(jt.nn.Module): def __init__(self, + in_channels, + out_channels, kernel_size, - stride=None, + stride=1, padding=0, - dilation=None, - return_indices=None, - ceil_mode=False, - count_include_pad=True, - op='maximum'): - super(PoolACL, self).__init__() - # set attr + dilation=1, + groups=1, + bias=True): + if in_channels <= 0: + raise ValueError( + f"in_channels must be greater than zero, got {in_channels}" + ) + if out_channels <= 0: + raise ValueError( + f"out_channels must be greater than zero, got {out_channels}" + ) + if groups <= 0: + raise ValueError( + f"groups must must be greater than zero, got {groups}") + assert in_channels % groups == 0, 'in_channels must be divisible by groups' + assert out_channels % groups == 0, 
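conv_acl above only forwards to ConvACL; after change_function() rebinds jt.nn.conv2d to warp(jt.nn.conv2d, conv_acl), ordinary Jittor code picks the ACL kernel whenever jt.flags.use_acl is on and the stock implementation otherwise. A usage sketch (it runs on any backend, the routing is decided by the flag):

import jittor as jt

x = jt.rand(1, 3, 8, 8)                 # NCHW input
w = jt.rand(4, 3, 3, 3)                 # out_channels x in_channels x Kh x Kw
y = jt.nn.conv2d(x, w, bias=None, stride=1, padding=1)
print(y.shape)                          # [1,4,8,8]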
'out_channels must be divisible by groups' + if isinstance(kernel_size, tuple): + for size in kernel_size: + if size <= 0: + raise ValueError( + f"kernel_size must be greater than zero, got {kernel_size}" + ) + else: + if kernel_size <= 0: + raise ValueError( + f"kernel_size must be greater than zero, got {kernel_size}" + ) + if isinstance(stride, tuple): + for size in stride: + if size <= 0: + raise ValueError( + f"stride must be greater than zero, got {stride}") + else: + if stride <= 0: + raise ValueError( + f"stride must be greater than zero, got {stride}") + if isinstance(padding, tuple): + for size in padding: + if size < 0: + raise ValueError( + f"padding must be nonnegative, got {padding}") + else: + if padding < 0: + raise ValueError( + f"padding must be nonnegative, got {padding}") + if isinstance(dilation, tuple): + for size in dilation: + if size <= 0: + raise ValueError( + f"dilation must be greater than zero, got {dilation}" + ) + else: + if dilation <= 0: + raise ValueError( + f"dilation must be greater than zero, got {dilation}") + self.in_channels = in_channels + self.out_channels = out_channels self.kernel_size = kernel_size if isinstance( kernel_size, tuple) else (kernel_size, kernel_size) - stride = stride if stride else kernel_size self.stride = stride if isinstance(stride, tuple) else (stride, stride) self.padding = padding if isinstance(padding, tuple) else (padding, padding) - dilation = dilation if dilation else 1 self.dilation = dilation if isinstance( dilation, tuple) else (dilation, dilation) - attr = {} + self.groups = groups + self.is_depthwise_conv = self.groups == self.out_channels and self.groups == self.in_channels + if self.is_depthwise_conv and jt.flags.use_cuda and jt.compiler.is_cuda: + self.depthwise_conv = jt.nn.DepthwiseConv( + stride, padding, dilation) + Kh, Kw = self.kernel_size - self.return_indices = return_indices - self.uint16 = jt.Var(1).int32().dtype - self.op = op - - if op == 'mean': - attr['exclusive'] = not count_include_pad - attr['global_pooling'] = False - attr['divisor_override_value'] = 0 - attr['ksize'] = [ - 1, 1, self.kernel_size[0], self.kernel_size[1] - ] - attr['strides'] = [1, 1, self.stride[0], self.stride[1]] - attr['ceil_mode'] = ceil_mode - attr['padding_mode'] = 'CALCULATED' - attr['data_format'] = 'NCHW' - elif op == 'maximum': - attr['ksize'] = [ - 1, self.kernel_size[0], self.kernel_size[1], 1 - ] - attr['strides'] = [1, self.stride[0], self.stride[1], 1] - attr['pads'] = [1, self.padding[0], self.padding[1], 1] - attr['dilation'] = [1, self.dilation[0], self.dilation[1], 1] - # attr['ceil_mode'] = ceil_mode - - self.attr = attr - - def execute(self, input): - - # create input - input_shape = input.shape - input_dtype = input.dtype - - self.input = input - # create output - output_shape = [ - input_shape[0], input_shape[1], - (input_shape[2] + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1, - (input_shape[3] + 2 * self.padding[1] - self.dilation[1] * - (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1 - ] - output_dtype = input_dtype - - if self.op == 'mean': - self.attr['pads'] = self.get_paddings() - result = acl_cmd("AvgPoolV2", [input], - output_dtypes=[output_dtype], - output_shapes=[output_shape], - attr=self.attr) - elif self.op == 'maximum': - result = acl_cmd("MaxPoolWithArgmaxV1", [input], - output_dtypes=[output_dtype, self.uint16], - output_shapes=[output_shape, output_shape], - attr=self.attr) + # self.weight = init.relu_invariant_gauss([out_channels, 
in_channels//groups, Kh, Kw], dtype="float", mode="fan_out") + self.weight = jt.init.invariant_uniform( + [out_channels, in_channels // groups, Kh, Kw], dtype="float") + if bias: + fan = 1 + for i in self.weight.shape[1:]: + fan *= i + bound = 1 / math.sqrt(fan) + self.bias = jt.init.uniform([out_channels], + dtype="float", + low=-bound, + high=bound) else: - raise ValueError('no this type pool') + self.bias = None - if self.op == 'maximum': - self.index = result[1] + def execute(self, x): + ret = jt.nn.conv2d(x, self.weight, self.bias, self.stride, + self.padding, self.dilation, self.groups) + return ret - if self.return_indices: - return result[0], result[1] + + from .aclops.flip_op import FlipACL + def flip_acl(x, dim): + return FlipACL()(x, dim) + + from .aclops.concat_op import ConcatACL + def concat(x, dim=0): + return ConcatACL()(x, dim) + + from .aclops.gather_scatter_op import GatherACL + + def gather_acl(input, dim, index): + return GatherACL()(input, dim, index) + + def any_acl(input, dim=None): + if dim is None: + if jt.sum(input != 0).item() > 0: + return jt.array([True]) else: - return result[0] + return jt.array([False]) + else: + return jt.sum(input != 0, dim=dim) > 0 - def grad(self, grad_output): - if self.op == 'maximum': - grad_input = acl_cmd("MaxPoolGradWithArgmaxV1", - [self.input, grad_output, self.index], - output_dtypes=[grad_output.dtype], - output_shapes=[self.input.shape], - attr=self.attr)[0] - elif self.op == 'mean': - grad_input = acl_cmd("AvgPoolV2", - [self.input, grad_output, self.index], - output_dtypes=[grad_output.dtype], - output_shapes=[self.input.shape], - attr=self.attr)[0] - else: - grad_input = None - return grad_input + from .aclops.cumsum_op import CumsumACL - class BmmACL(Function): + def cumsum_acl(input, dim=-1): + return CumsumACL()(input, dim) - def __init__(self, adj_x1=False, adj_x2=False): - super(BmmACL, self).__init__() - self.adj_x1 = adj_x1 - self.adj_x2 = adj_x2 + def cumprod_acl(x, dim=None): + x = jt.log(x) + x = cumsum_acl(x, dim=dim) + return jt.exp(x) - def execute(self, x1, x2): - self.input = [x1, x2] - result = acl_cmd("BatchMatMul", [x1, x2], - output_dtypes=[x1.dtype], - output_shapes=[x1.shape[:-1] + x2.shape[-1:]], - attr={})[0] - return result + from .aclops.index_op import IndexACL - def grad(self, grad_output): - x1, x2 = self.input - grad_x1 = acl_cmd( - "BatchMatMul", [grad_output, x2.transpose(-2, -1)], - output_dtypes=[x1.dtype], - output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]], - attr={})[0] - grad_x2 = acl_cmd( - "BatchMatMul", [x1.transpose(-2, -1), grad_output], - output_dtypes=[x2.dtype], - output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]], - attr={})[0] - return grad_x1, grad_x2 + def index_acl(inshape: Union[jt.Var, list], dim=None, dtype="int32"): + if isinstance(inshape, jt.Var): + inshape = inshape.shape + return IndexACL()(inshape, dim, dtype) - class MatmulACL(Function): + from .aclops.gather_scatter_op import ScatterACL + def scatter_acl(input, dim, index, src, reduce='void'): + return ScatterACL()(input, dim, index, src, reduce) - def __init__(self, adj_x1=False, adj_x2=False): - super(MatmulACL, self).__init__() - self.adj_x1 = adj_x1 - self.adj_x2 = adj_x2 + from .aclops.where_op import WhereACL - def execute(self, x1, x2): - self.input = [x1, x2] - if len(x1.shape) > 2 or len(x2.shape) > 2: - result = acl_cmd("BatchMatMul", [x1, x2], - output_dtypes=[x1.dtype], - output_shapes=[x1.shape[:-1] + x2.shape[-1:]], - attr={})[0] - else: - result = acl_cmd("MatMul", [x1, x2], - 
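cumprod_acl above rewrites the cumulative product through logarithms, cumprod(x)[i] = exp(log x_0 + ... + log x_i), which is exact only for strictly positive inputs (log is undefined at 0 and for negatives). A quick numeric check of the identity, using NumPy just for the arithmetic:

import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
via_logs = np.exp(np.cumsum(np.log(x)))   # the route taken by cumprod_acl
print(via_logs)                           # [ 1.  2.  6. 24.]
print(np.cumprod(x))                      # [ 1.  2.  6. 24.]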
output_dtypes=[x1.dtype], - output_shapes=[x1.shape[:-1] + x2.shape[-1:]], - attr={})[0] - return result + def where_acl(condition, x=None, y=None): + return WhereACL()(condition, x, y) - def grad(self, grad_output): - x1, x2 = self.input - if len(x1.shape) > 2 or len(x2.shape) > 2: - grad_x1 = acl_cmd( - "BatchMatMul", - [grad_output, x2.transpose(-2, -1)], - output_dtypes=[x1.dtype], - output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]], - attr={})[0] - grad_x2 = acl_cmd( - "BatchMatMul", [x1.transpose(-2, -1), grad_output], - output_dtypes=[x2.dtype], - output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]], - attr={})[0] - else: - grad_x1 = acl_cmd( - "MatMul", [grad_output, x2.transpose(-2, -1)], - output_dtypes=[x1.dtype], - output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]], - attr={})[0] - grad_x2 = acl_cmd( - "MatMul", [x1.transpose(-2, -1), grad_output], - output_dtypes=[x2.dtype], - output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]], - attr={})[0] - return grad_x1, grad_x2 + from .aclops.where_op import NonzeroACL - class GetItem(Function): + def nonzero_acl(x): + return NonzeroACL()(x) - def __init__(self): - super(GetItem, self).__init__() - self.type_ = 'index' + from .aclops.floor_op import FloorIntACL - def stride(self, x, dim): - stride = 1 - for i in range(dim + 1, len(x.shape)): - stride *= x.shape[i] - return stride + def floor_int_acl(x): + return FloorIntACL()(x) - def execute(self, x, slices, return_x=None): - if isinstance(slices, jt.Var) or isinstance(slices, tuple): - if isinstance(slices, jt.Var): - slices = (slices, ) - if isinstance(slices[0], jt.Var): - slices_len = len(slices) - masks = jt.ones(slices_len, dtype=jt.int64) - output = slices[0].shape - output += x.shape[slices_len:] - input_ = [x, masks, jt.Var(list(output)).int64()] - for i in range(slices_len): - input_.append(slices[i].int32()) - result = acl_cmd("Index", - input_, - output_dtypes=[x.dtype], - output_shapes=[output], - attr={})[0] - self.shape = x.shape - self.sizes = list(output) - self.type_ = 'index' - self.slices = slices - # self.strides - return result + from .aclops.getitem_op import GetItemACL - # use AsStrided operator to implement the getitem function - # get the shape and stride of the input tensor - x_dim = len(x.shape) - # int type - if not isinstance(slices, tuple): - slices = (slices, ) + def getitem_acl(x, slices, return_x=None): + # Transform numpy int to int + if isinstance(slices, (np.int8, np.int16, np.int32, np.int64)): + slices = int(slices) + if hasattr(np, 'int128') and isinstance(slices, np.int128): + slices = int(slices) + if hasattr(np, 'int256') and isinstance(slices, np.int256): + slices = int(slices) - if len(slices) < x_dim: - slices += (slice(None, None, None), ) * (x_dim - len(slices)) + ## If not related to `None`, directly use `GetItemACL` + if slices is not None and (not isinstance(slices, Iterable) + or all([s is not None for s in slices])): + return GetItemACL()(x, slices, return_x) - self.inputs = [x, slices] + ## If related to `None`, filter out `None` first, then use `GetItemACL`, and finally insert `None` (new dimensions) back - sizes = [] - strides = [] - offset = 0 + # Transform to tuple + if isinstance(slices, int) or isinstance(slices, slice): + slices = (slices, ) + assert isinstance(slices, tuple) - for dim, s in enumerate(slices): + def get_insert_positions(slices): + result = [] + pos = 0 + + not_none_cnt = len(slices) - slices.count(None) + for s in slices: if isinstance(s, int): - if s < 0: # Handle negative indices. 
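getitem_acl above handles None (new-axis) entries by stripping them out, indexing with the remaining slices, and then unsqueezing the result at the positions computed by get_insert_positions. A pure-Python illustration of that bookkeeping, checked against NumPy's behaviour:

import numpy as np

def insert_positions(slices, ndim):
    # mirrors get_insert_positions above
    result, pos = [], 0
    not_none_cnt = len(slices) - slices.count(None)
    for s in slices:
        if isinstance(s, int):
            continue                          # an int index removes a dimension
        elif s is None:
            result.append(pos)                # a new axis appears here
            pos += 1
        elif s is Ellipsis:
            pos += 1 + ndim - not_none_cnt
        else:
            pos += 1
    return result

slices = (None, 0, slice(None), None)
print(insert_positions(slices, ndim=2))       # [0, 2]
x = np.arange(6).reshape(2, 3)
print(x[None, 0, :, None].shape)              # (1, 3, 1): index without None, then unsqueeze at 0 and 2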
- s += x.shape[dim] - offset += s * self.stride(x, dim) - elif isinstance(s, slice): - # Unpack the slice - start, stop, step = s.indices(x.size(dim)) - size = (stop - start - 1) // step + 1 - stride = self.stride(x, dim) * step - offset += start * self.stride(x, dim) - sizes.append(size) - strides.append(stride) + continue + elif s is None: + result.append(pos) + pos += 1 + elif s == Ellipsis: + pos += 1 + x.ndim - not_none_cnt else: - raise ValueError("Invalid slice type") + pos += 1 - if not sizes: - sizes = [1] - strides = [0] - # AsStrided same with as_strided of pytorch - self.sizes = sizes - self.strides = strides - self.offset = offset - self.shape = x.shape - self.type_ = 'as_strided' - result = acl_cmd( - "AsStrided", - [x, jt.Var(sizes), - jt.Var(strides), - jt.Var(offset)], - output_dtypes=[x.dtype], - output_shapes=[jt.empty(sizes).shape], - attr={})[0] return result - def grad(self, grad_output): - if self.type_ == 'as_strided': - result = jt.zeros(self.shape, dtype=grad_output.dtype) - sizes = list(grad_output.shape) - strides = [ - self.stride(grad_output, dim) - for dim in range(len(grad_output.shape)) - ] - result = acl_cmd("ViewCopy", [ - result, - jt.Var(self.sizes), - jt.Var(self.strides), - jt.Var(self.offset), grad_output, - jt.Var(sizes), - jt.Var(strides), - jt.Var(0) - ], - output_dtypes=[result.dtype], - output_shapes=[result.shape], - attr={})[0] - elif self.type_ == 'index': - #TODO: use IndexPutV2 to implement the grad function - assert len(self.slices) == 1 - index = self.slices[0] - input = jt.zeros(self.shape, dtype=grad_output.dtype) - input_flatten = input.reshape(input.shape[0], -1) - index_flatten = index.reshape(-1).unsqueeze(-1).repeat( - 1, input_flatten.shape[1]) - grad_output_flatten = grad_output.reshape(index.numel(), -1) - result = acl_cmd( - "ScatterElements", - [input_flatten, index_flatten, grad_output_flatten], - output_dtypes=[input.dtype], - output_shapes=[input.shape], - attr={ - 'axis': 0, - 'reduction': 'add' - })[0] - result = result.reshape(self.shape) - # result = jt.zeros(self.shape, dtype=grad_output.dtype) - # # masks = jt.ones(len(self.slices), dtype=jt.int64) - # masks = jt.array([1,1], dtype=jt.int64) - # expand_masks = jt.array([1,1], dtype=jt.int64) - # inputs_ = [result,grad_output,masks,expand_masks] - # slices_len = len(self.slices) - # for i in range(slices_len): - # inputs_.append(self.slices[i].int64()) - # # breakpoint() - # jt.sync_all(True) - # print(inputs_) - # result_ = acl_cmd("IndexPutV2", inputs_, - # output_dtypes=[result.dtype], - # output_shapes=[result.shape], - # attr={"accumulate":True})[0] - # result = result_ - else: - raise ValueError("Invalid slice type") - result.sync() - return result, None + insert_positions = get_insert_positions(slices) + slices_without_none = tuple(s for s in slices if s is not None) + result = GetItemACL()(x, slices_without_none, return_x) - class ConcatACL(Function): + for i in insert_positions: + result = result.unsqueeze(i) + + return result + + + from .aclops.setitem_op import SetItemACL + + def setitem_acl(x, slices, value): + res = SetItemACL()(x, slices, value) + return x.assign(res) + + + from .aclops.bmm_op import BmmACL + + def bmm_acl(x1, x2): + return BmmACL()(x1, x2) + + def bmm_transpose_acl(x1, x2): + return BmmACL(True)(x1, x2) + + + from .aclops.matmul_op import MatmulACL + + def matmul_acl(x1, x2): + return MatmulACL()(x1, x2) + + def matmul_transpose_acl(x1, x2): + return MatmulACL(True)(x1, x2) + + from .aclops.transpose_op import TransPoseACL + + def 
transpose_acl(x, *dim): + return TransPoseACL()(x, *dim) + + from .aclops.relu_op import ReLUACL + class ReLU(jt.nn.Module): def __init__(self): - super(ConcatACL, self).__init__() + super(ReLU, self).__init__() - def execute(self, input_tensors, dim=0): - self.input = input_tensors - for i in range(len(input_tensors)): - if input_tensors[i].dtype != input_tensors[0].dtype: - raise ValueError( - "All input tensors must have the same dtype") - if input_tensors[i].shape[:dim] != input_tensors[ - 0].shape[:dim] or input_tensors[i].shape[ - dim + 1:] != input_tensors[0].shape[dim + 1:]: - raise ValueError( - "All input tensors must have the same shape") - result = acl_cmd( - "ConcatD", - input_tensors, - output_dtypes=[input_tensors[0].dtype], - output_shapes=[ - jt.empty(self.calculate_output_shape(input_tensors, - dim)).shape - ], - attr={ - "N": len(input_tensors), - "concat_dim": dim - })[0] - return result + def execute(self, x): + return ReLUACL()(x) - def grad(self, grad_output): - grad_inputs = self.split_grad(grad_output, self.input, self.axis) - return grad_inputs + def relu(x): + return ReLUACL()(x) - def calculate_output_shape(self, input_tensors, axis): - shape = list(input_tensors[0].shape) - for tensor in input_tensors[1:]: - shape[axis] += tensor.shape[axis] - return tuple(shape) + from .aclops.relu_op import LeakyReLUACL - def split_grad(self, grad_output, input_tensors, axis): - offset = 0 - grad_inputs = [] - for tensor in input_tensors: - grad_input = acl_cmd("Slice", [ - grad_output, [0] * axis + [offset] + [0] * - (len(tensor.shape) - axis - 1), tensor.shape - ]) - grad_inputs.append(grad_input) - offset += tensor.shape[axis] - return grad_inputs + class LeakyReLU(jt.nn.Module): - class SetItemACL(Function): + def __init__(self, negative_slope=0.01): + super(LeakyReLU, self).__init__() + self.negative_slope = negative_slope + + def execute(self, x): + return LeakyReLUACL()(x, self.negative_slope) + + def leaky_relu(x, scale=0.01): + return LeakyReLUACL()(x, scale) + + from .aclops.dropout_op import DropoutACL + + class Dropout(jt.nn.Module): + + def __init__(self, p=0.5, is_train=False): + super(Dropout, self).__init__() + self.p = p + self.is_train = is_train + + def execute(self, x): + return DropoutACL()(x, self.p, self.is_train) + + def dropout_acl(x, p=0.5, is_train=False): + return DropoutACL()(x, p, is_train) + + from .aclops.silu_op import SiLUACL + + def silu_acl(x): + return SiLUACL()(x) + + class SiLU(jt.nn.Module): def __init__(self): - super(SetItemACL, self).__init__() + super(SiLU, self).__init__() - def stride(self, x, dim): - # 计算给定维度的步长 - stride = 1 - for i in range(dim + 1, len(x.shape)): - stride *= x.shape[i] - return stride + def execute(self, x): + return SiLUACL()(x) - def execute(self, x, slices, value, reduce='void'): - self.is_tensor = type(value) == jt.Var - if type(value) != jt.Var: - value = jt.array(value) - x_dim = len(x.shape) + from .aclops.sigmoid_op import SigmoidACL + + def sigmoid_acl(x): + return SigmoidACL()(x) - # 确保slices是一个元组 - if not isinstance(slices, tuple): - slices = (slices, ) + class Sigmoid(jt.nn.Module): - # 补齐slices使其长度等于x的维度 - if len(slices) < x_dim: - slices += (slice(None, None, None), ) * (x_dim - len(slices)) + def __init__(self): + super(Sigmoid, self).__init__() - self.inputs = [x, slices, value] + def execute(self, x): + return SigmoidACL()(x) - target_sizes = [] - target_strides = [] - offset = 0 + # class Embedding(jt.nn.Module): - for dim, s in enumerate(slices): - if isinstance(s, int): - if s < 0: - s += 
x.shape[dim] - s = slice(s, s + 1, None) - if isinstance(s, slice): - # 解包切片 - start, stop, step = s.indices(x.shape[dim]) - size = (stop - start - 1) // step + 1 - stride = self.stride(x, dim) * step - offset += start * self.stride(x, dim) - target_sizes.append(size) - target_strides.append(stride) + # def __init__(self, + # num_embeddings, + # embedding_dim, + # padding_idx=None, + # dtype="float32"): + # self.num_embeddings = num_embeddings + # self.embedding_dim = embedding_dim + # self.padding_idx = padding_idx + # self.weight = jt.init.gauss( + # [self.num_embeddings, self.embedding_dim], dtype) + # if padding_idx is not None: + # self.weight[padding_idx] = 0 + + # def execute(self, x): + # res = embedding_acl(x, self.weight) + # return res + + class Softmax(jt.nn.Module): + + def __init__(self): + super(Softmax, self).__init__() + + def execute(self, x, dim): + return SoftmaxACL()(x, dim) + + def softmax_acl(x, dim): + return SoftmaxACL()(x, dim) + + from .aclops.rope_op import RopeACL + def rope_acl(xq, xk, freqs_cis=None, freq_sin=None, freq_cos=None): + return RopeACL()(xq, xk, freqs_cis, freq_sin, freq_cos) + + from .aclops.stack_op import StackACL + def stack_acl(x, dim=0): + return StackACL()(x, dim) + + from .aclops.nantonum_op import NanToNumACL + + def isnan_acl(x): + tonum = NanToNumACL()(x, -1.0) + return jt.not_equal(x, tonum).logical_and( + jt.not_equal(tonum, jt.ones_like(x))) + + def isinf_acl(x): + tonum = NanToNumACL()(x, 1.0) + return jt.not_equal(x, tonum).logical_and( + jt.not_equal(tonum, jt.ones_like(x))) + + def warp(origin_func, new_func, name=None): + + if isinstance(origin_func, type): + + class WrappedClass(origin_func, new_func): + + def __init__(self, *args, **kwargs): + if jt.flags.use_acl: + new_func.__init__(self, *args, **kwargs) + else: + origin_func.__init__(self, *args, **kwargs) + + def execute(self, *args, **kwargs): + if jt.flags.use_acl: + return new_func.execute(self, *args, **kwargs) + elif name == 'setitem': + return args[0].assign(origin_func(*args, **kwargs)) + else: + return origin_func.execute(self, *args, **kwargs) + + return WrappedClass + + else: + + def warpper(*args, **kwargs): + if jt.flags.use_acl: + return new_func(*args, **kwargs) + elif name == 'setitem': + return args[0].assign(origin_func(*args, **kwargs)) else: - print("slices: ", s, type(s)) - raise ValueError("Invalid slice type") + return origin_func(*args, **kwargs) - # 计算value的size、stride和offset - value_sizes = list(value.shape) - value_strides = [ - self.stride(value, dim) for dim in range(len(value.shape)) - ] + return warpper - self.target_sizes = target_sizes - self.target_strides = target_strides - self.offset = offset - self.value_sizes = value_sizes - self.value_strides = value_strides + jt.triu = warp(jt.triu, triu_acl) + jt.triu_ = warp(jt.triu, triu_acl) + jt.Var.triu = jt.triu + jt.Var.triu_ = lambda x, diagonal=0: x.assign(x.triu(diagonal)) + jt.nn.conv2d = warp(jt.nn.conv2d, conv_acl) + jt.nn.Conv2d = warp(jt.nn.Conv2d, Conv2D) + jt.nn.Conv = warp(jt.nn.Conv, Conv2D) - #import pdb; pdb.set_trace() - result = acl_cmd("ViewCopy", [ - x, - jt.Var(target_sizes), - jt.Var(target_strides), - jt.Var(offset), value, - jt.Var(value_sizes), - jt.Var(value_strides), - jt.Var(0) - ], - output_dtypes=[x.dtype], - output_shapes=[x.shape], - attr={})[0] - result.sync() - return result - - def grad(self, grad_output): - result = acl_cmd("AsStrided", [ - grad_output, - jt.Var(self.target_sizes), - jt.Var(self.target_strides), - jt.Var(self.offset) - ], - 
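warp() above is the single dispatch point for all of these patches: when the target is a class it builds a subclass whose __init__ and execute switch on jt.flags.use_acl, and when it is a function it returns a closure doing the same check per call (with an in-place fallback for setitem). A reduced model of the function branch, with a plain boolean standing in for the flag:

def warp_sketch(origin_func, new_func, use_acl_flag):
    # the real warp() also covers classes and the name == 'setitem' case
    def wrapper(*args, **kwargs):
        if use_acl_flag:
            return new_func(*args, **kwargs)    # ACL kernel path
        return origin_func(*args, **kwargs)     # stock Jittor path
    return wrapper

relu_like = warp_sketch(lambda x: max(x, 0.0),      # stand-in for the original op
                        lambda x: max(x, 0.0),      # stand-in for the ACL op
                        use_acl_flag=False)
print(relu_like(-2.0), relu_like(1.5))              # 0.0 1.5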
output_dtypes=[grad_output.dtype], - output_shapes=[jt.empty(self.target_sizes).shape], - attr={})[0] - # copy grad_output to new_grad_output - new_grad_output = acl_cmd("Copy", [grad_output], - output_dtypes=[grad_output.dtype], - output_shapes=[grad_output.shape], - attr={"N": 1})[0] - new_grad_output = acl_cmd("ViewCopy", [ - new_grad_output, - jt.Var(self.target_sizes), - jt.Var(self.target_strides), - jt.Var(self.offset), - jt.zeros(self.value_sizes, dtype=grad_output.dtype), - jt.Var(self.value_sizes), - jt.Var(self.value_strides), - jt.Var(0) - ], - output_dtypes=[grad_output.dtype], - output_shapes=[grad_output.shape], - attr={})[0] - new_grad_output.sync() - return new_grad_output, None, result if self.is_tensor else None - - class TriuACL(Function): - - def __init__(self): - super(TriuACL, self).__init__() - - def execute(self, input, k): - self.input = input - result = acl_cmd("Triu", [input], - output_dtypes=[input.dtype], - output_shapes=[input.shape], - attr={'diagonal': k})[0] - return result - - def grad(self, grad_output): - return grad_output - - class TransposeACL(Function): - - def __init__(self): - super(TransposeACL, self).__init__() - - def execute(self, input, perm): - self.input = input - - output_shape = input.shape[perm[0]:perm[0] + 1] - - for i in range(1, len(perm)): - output_shape += input.shape[perm[i]:perm[i] + 1] - result = acl_cmd("Transpose", [input, jt.Var(perm)], - output_dtypes=[input.dtype], - output_shapes=[output_shape], - attr={})[0] - return result - - def grad(self, grad_output): - return grad_output - - class AdaptiveMaxPool2dACL(Function): - - def __init__( - self, - output_size, - return_indices=False, - ): - super(AdaptiveMaxPool2dACL, self).__init__() - self.output_size = (output_size, output_size) if isinstance( - output_size, int) else output_size - - self.return_indices = return_indices - self.uint16 = jt.Var(1).int32().dtype - - attr = {} - attr['ceil_mode'] = False - attr['dilations'] = [1, 1, 1, 1] - self.attr = attr - - def execute(self, input): - input_shape = input.shape - input_dtype = input.dtype - - output_shape = [ - input_shape[0], input_shape[1], self.output_size[0], - self.output_size[1] - ] - output_dtype = input_dtype - self.input = input - - stride_h = input_shape[2] // output_shape[2] - stride_w = input_shape[3] // output_shape[3] - kernel_size_h = input_shape[2] - (output_shape[2] - 1) * stride_h - kernel_size_w = input_shape[3] - (output_shape[3] - 1) * stride_w - - stride = [0, 0] - kernel_size = [0, 0] - padding = [0, 0] - - stride[0] = stride_h - stride[1] = stride_w - kernel_size[0] = kernel_size_h - kernel_size[1] = kernel_size_w - padding[0] = padding[1] = 0 - kernel_sizes = [1, kernel_size[0], kernel_size[1], 1] - strides_size = [1, stride[0], stride[1], 1] - paddings = [1, padding[0], padding[1], 1] - - self.attr['ksize'] = kernel_sizes - self.attr['strides'] = strides_size - self.attr['pads'] = paddings - - result = acl_cmd("MaxPoolWithArgmaxV1", [input], - output_dtypes=[output_dtype, self.uint16], - output_shapes=[output_shape, output_shape], - attr=self.attr) - - self.index = result[1] - - if self.return_indices: - return result[0], result[1] - else: - return result[0] - - def grad(self, grad_output): - grad_input = acl_cmd("MaxPoolGradWithArgmaxV1", - [self.input, grad_output, self.index], - output_dtypes=[grad_output.dtype], - output_shapes=[self.input.shape], - attr=self.attr)[0] - return grad_input - - class AdaptiveAvgPool2dACL(Function): - - def __init__(self, output_size): - 
super(AdaptiveAvgPool2dACL, self).__init__() - self.output_size = (output_size, output_size) if isinstance( - output_size, int) else output_size - - attr = {} - if isinstance(output_size, tuple): - output_size = [output_size[0], output_size[1]] - attr['output_size'] = output_size - self.attr = attr - - def execute(self, input): - input_shape = input.shape - input_dtype = input.dtype - - self.original_shape = input_shape - - output_shape = [ - input_shape[0], input_shape[1], self.attr['output_size'][0], - self.attr['output_size'][1] - ] - output_dtype = input_dtype - self.input = input - - result = acl_cmd("AdaptiveAvgPool2d", [input], - output_dtypes=[output_dtype], - output_shapes=[output_shape], - attr=self.attr) - - return result[0] - - def grad(self, grad_output): - attr = {} - attr['orig_input_shape'] = list(self.original_shape) - grad_input = acl_cmd("AdaptiveAvgPool2dGrad", [grad_output], - output_dtypes=[grad_output.dtype], - output_shapes=[self.original_shape], - attr=attr)[0] - return grad_input - - class CumsumACL(Function): - - def __init__(self): - super(CumsumACL, self).__init__() - - def execute(self, input, dim=-1): - self.input = input - self.dim = dim - result = acl_cmd("Cumsum", [input, jt.Var(dim)], - output_dtypes=[input.dtype], - output_shapes=[input.shape], - attr={})[0] - return result - - def grad(self, grad_output): - flipped_grad_output = acl_cmd( - "ReverseV2", [grad_output, jt.Var([self.dim])], - output_dtypes=[grad_output.dtype], - output_shapes=[grad_output.shape], - attr={})[0] - cumulative_grad = acl_cmd( - "Cumsum", - [flipped_grad_output, jt.Var(self.dim)], - output_dtypes=[grad_output.dtype], - output_shapes=[grad_output.shape], - attr={})[0] - grad_input = acl_cmd( - "ReverseV2", - [cumulative_grad, jt.Var([self.dim])], - output_dtypes=[grad_output.dtype], - output_shapes=[grad_output.shape], - attr={})[0] - return grad_input - - class GatherACL(Function): - - def __init__(self): - super(GatherACL, self).__init__() - - def execute(self, input, dim, index): - self.input = input - self.dim = dim - self.index = index - - result = acl_cmd("GatherElements", [input, index], - output_dtypes=[input.dtype], - output_shapes=[index.shape], - attr={'dim': dim})[0] - return result - - def grad(self, grad_output): - tmp = jt.zeros(self.index.shape, dtype=grad_output.dtype) - grad_input = acl_cmd("ScatterElements", - [tmp, self.index, grad_output], - output_dtypes=[grad_output.dtype], - output_shapes=[tmp.shape], - attr={ - 'axis': self.dim, - 'reduction': "add" - })[0] - return grad_input - - class ScatterACL(Function): - - def __init__(self): - super(ScatterACL, self).__init__() - - def execute(self, input, dim, index, src, reduce='void'): - self.input = input - self.dim = dim - self.index = index - self.reduce = reduce - result = acl_cmd("ScatterElements", [input, self.index, src], - output_dtypes=[input.dtype], - output_shapes=[input.shape], - attr={ - 'axis': self.dim, - 'reduction': reduce - })[0] - return result - - def grad(self, grad_output): - grad_input = acl_cmd("GatherElements", [grad_output, self.index], - output_dtypes=[grad_output.dtype], - output_shapes=[self.index.shape], - attr={'dim': self.dim})[0] - return grad_output, None, None, grad_input - - class WhereACL(Function): - - def __init__(self): - super(WhereACL, self).__init__() - - def execute(self, condition, x, y): - self.condition = condition - - if x.dtype != y.dtype: - if x.dtype == jt.float32: - y = y.float32() - elif y.dtype == jt.float32: - x = x.float32() - else: - x = x.to(y.dtype) - - 
self.x = x - self.y = y - - result = acl_cmd("Select", [condition, x, y], - output_dtypes=[x.dtype], - output_shapes=[x.shape], - attr={})[0] - return result - - def grad(self, grad_output): - tmp = jt.zeros(grad_output.shape, dtype=grad_output.dtype) - grad_x = acl_cmd("Select", [self.condition, grad_output, tmp], - output_dtypes=[self.x.dtype], - output_shapes=[self.x.shape], - attr={})[0] - - grad_y = acl_cmd("Select", [self.condition, tmp, grad_output], - output_dtypes=[self.y.dtype], - output_shapes=[self.y.shape], - attr={})[0] - return grad_output, grad_x, grad_y - - class FlipACL(Function): - - def __init__(self): - super(FlipACL, self).__init__() - - def execute(self, input, dim): - self.input = input - #if isinstance(dim_vector, tuple): - dim_vector = jt.Var(list(dim)) - #print(dim_vector.dtype) - self.dim_vector = dim_vector - #print(input, dim_vector) - result = acl_cmd("ReverseV2", [input, dim_vector], - output_dtypes=[input.dtype], - output_shapes=[input.shape], - attr={})[0] - return result - - def grad(self, grad_output): - #print(grad_output) - grad_input = acl_cmd("ReverseV2", [grad_output, self.dim_vector], - output_dtypes=[grad_output.dtype], - output_shapes=[grad_output.shape], - attr={})[0] - return grad_input - - class FloorIntACL(Function): - - def __init__(self): - super(FloorIntACL, self).__init__() - - def execute(self, input): - self.input = input - self.shape = input.shape - result = acl_cmd("Floor", [input], - output_dtypes=[jt.int], - output_shapes=[input.shape], - attr={})[0] - return result - - def grad(self, grad_output): - return jt.zeros(self.shape, dtype=grad_output.dtype) - - def warp(origin_func, new_func): - - def warpper(*args, **kwargs): - if origin_func == jt.index: - if len(args) == 2 and args[1] == None: - args = tuple(list(args[0:1])) - if jt.flags.use_acl: - if isinstance(new_func, IndexACL): - if len(args) == 1: - args = (args[0], None) - if isinstance(new_func, CumsumACL): - args = (args[0], kwargs.get('dim', -1)) - kwargs = {} - if isinstance(new_func, - ScatterACL) and kwargs.get('reduce') is not None: - args = (args[0], args[1], args[2], args[3], - kwargs.get('reduce', 'void')) - kwargs = {} - - return new_func(*args, **kwargs) - return origin_func(*args, **kwargs) - - return warpper - - jt.index = warp(jt.index, IndexACL()) - jt.Var.index = lambda x, dim=None: warp(jt.index, IndexACL())(x.shape, dim) jt.nn.Pool = warp(jt.nn.Pool, PoolACL) - jt.nn.AdaptiveMaxPool2d = warp(jt.nn.AdaptiveMaxPool2d, - AdaptiveMaxPool2dACL) - jt.nn.AdaptiveAvgPool2d = warp(jt.nn.AdaptiveAvgPool2d, - AdaptiveAvgPool2dACL) - jt.triu = warp(jt.triu, TriuACL()) - jt.triu_ = warp(jt.triu, TriuACL()) - jt.Var.triu = lambda x: warp(jt.Var.triu, TriuACL())(x) - jt.Var.triu_ = lambda x: warp(jt.Var.triu_, TriuACL())(x) + jt.flip = warp(jt.flip, flip_acl) + jt.Var.flip = lambda x, dim_vector=0: jt.flip(x, dim_vector) + jt.concat = warp(jt.concat, concat) + jt.stack = warp(jt.stack, stack_acl) - jt.getitem = warp(jt.getitem, GetItem()) + jt.gather = warp(jt.gather, gather_acl) + jt.any = warp(jt.any, any_acl) + jt.Var.any = jt.any + + jt.cumsum = warp(jt.cumsum, cumsum_acl) + jt.cub_cumsum = jt.cumsum + jt.Var.cumsum = jt.cumsum + jt.Var.cub_cumsum = jt.cumsum + + jt.cumprod = warp(jt.cumprod, cumprod_acl) + jt.index = warp(jt.index, index_acl) + jt.Var.index = jt.index + + jt.scatter = warp(jt.scatter, scatter_acl) + jt.Var.scatter = lambda x, dim, index, src, reduce="void": jt.scatter( + x, dim, index, src, reduce) + + jt.where = warp(jt.where, where_acl) + jt.nonzero 
= warp(jt.nonzero, nonzero_acl) + jt.misc.nonzero = warp(jt.misc.nonzero, nonzero_acl) + jt.Var.nonzero = jt.misc.nonzero + jt.floor_int = warp(jt.floor_int, floor_int_acl) + jt.Var.floor_int = lambda x: jt.floor_int(x) + + jt.getitem = warp(jt.contrib.getitem, getitem_acl) + fake_getitem = jt.Var.getitem jt.Var.getitem = lambda x, slices, return_x=None: warp( - jt.getitem, GetItem())(x, slices) + fake_getitem, getitem_acl)(x, slices) + jt.Var.slice_var = lambda x, slices, return_x=None: warp( + fake_getitem, getitem_acl)(x, slices) + jt.Var.__getitem__ = lambda x, slices, return_x=None: warp( + fake_getitem, getitem_acl)(x, slices) - jt.setitem = warp(jt.setitem, SetItemACL()) - jt.Var.setitem = lambda x, slices, value, reduce='void': warp( - jt.setitem, SetItemACL())(x, slices, value, reduce) + jt.setitem = warp(jt.contrib.setitem, setitem_acl) + fake_setitem = jt.Var.setitem + jt.Var.setitem = lambda x, slices, value: warp( + fake_setitem, setitem_acl, name='setitem')(x, slices, value) + jt.Var.__setitem__ = lambda x, slices, value: warp( + fake_setitem, setitem_acl, name='setitem')(x, slices, value) - jt.misc.flip = warp(jt.misc.flip, FlipACL()) - jt.Var.flip = lambda x, dim_vector: warp(jt.misc.flip, FlipACL())( - x, dim_vector) - jt.cumsum = warp(jt.cumsum, CumsumACL()) - jt.gather = warp(jt.gather, GatherACL()) - jt.Var.gather = lambda x, dim, index: warp(jt.gather, GatherACL())(x, dim, - index) - jt.scatter = warp(jt.scatter, ScatterACL()) - jt.Var.scatter = lambda x, dim, index, src, reduce="void": warp( - jt.scatter, ScatterACL())(x, dim, index, src, reduce) - jt.where = warp(jt.where, WhereACL()) - jt.floor_int = warp(jt.floor_int, FloorIntACL()) - jt.Var.floor_int = lambda x: warp(jt.floor_int, FloorIntACL())(x) + fake_matmul = jt.Var.matmul + jt.nn.bmm = warp(jt.nn.bmm, bmm_acl) + jt.bmm = warp(jt.bmm, bmm_acl) + jt.nn.matmul = warp(jt.matmul, matmul_acl) + jt.matmul = warp(jt.matmul, matmul_acl) + jt.nn.matmul_transpose = warp(jt.nn.matmul_transpose, matmul_transpose_acl) + jt.nn.bmm_transpose = warp(jt.nn.bmm_transpose, bmm_transpose_acl) + jt.bmm_transpose = warp(jt.bmm_transpose, bmm_transpose_acl) + jt.Var.__matmul__ = lambda x, y: warp(fake_matmul, matmul_acl)(x, y) - # jt.nn.bmm = warp(jt.nn.bmm, BmmACL()) - # jt.bmm = warp(jt.bmm, BmmACL()) - # jt.nn.matmul = warp(jt.matmul, MatmulACL()) - # jt.matmul = warp(jt.matmul, MatmulACL()) - # jt.transpose = warp(jt.transpose, TransposeACL()) - # jt.Var.transpose = lambda x, perm: warp(jt.transpose, TransposeACL())(x, perm) - # jt.concat = warp(jt.concat, ConcatACL()) + jt.transpose = warp(jt.transpose, transpose_acl) + fake_transpose = jt.transpose + jt.Var.transpose = lambda x, *dim: warp(fake_transpose, transpose_acl)(x, * + dim) + # jt.Var.permute = lambda x: warp(fake_transpose, transpose_acl)(x) + # jt.Var.t = lambda x: warp(fake_transpose, transpose_acl)(x) + + jt.nn.relu = warp(jt.nn.relu, relu) + jt.nn.ReLU = warp(jt.nn.ReLU, ReLU) + + jt.nn.leaky_relu = warp(jt.nn.leaky_relu, leaky_relu) + jt.nn.LeakyReLU = warp(jt.nn.LeakyReLU, LeakyReLU) + + # jt.nn.silu = warp(jt.nn.silu, silu_acl) + # jt.nn.SiLU = warp(jt.nn.SiLU, SiLU) + + jt.sigmoid = warp(jt.sigmoid, sigmoid_acl) + jt.nn.Sigmoid = warp(jt.nn.Sigmoid, Sigmoid) + + # from .aclops.embedding_op import EmbeddingACL + # def embedding_acl(indices, weight): + # return EmbeddingACL()(indices, weight) + + # jt.nn.embedding = warp(jt.nn.embedding, embedding_acl) + # jt.nn.Embedding = warp(jt.nn.Embedding, Embedding) + jt.nn.dropout = warp(jt.nn.dropout, dropout_acl) + 
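Once these bindings are installed (change_function() is invoked from the ACL backend's post_process), user code stays ordinary Jittor and each call is routed by jt.flags.use_acl. A usage sketch, assuming an Ascend build for the NPU path; on other backends the same lines fall through to the stock kernels:

import jittor as jt

a = jt.rand(2, 3)
b = jt.rand(3, 4)
c = jt.matmul(a, b)     # MatmulACL when use_acl is set, stock matmul otherwise
d = jt.nn.relu(c)       # ReLUACL when use_acl is set, stock relu otherwise
print(d.shape)          # [2,4]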
jt.nn.Dropout = warp(jt.nn.Dropout, Dropout) + + jt.nn.softmax = warp(jt.nn.softmax, softmax_acl) + + # from .aclops.norms_op import BatchNormACL,LayerNormACL + # jt.nn.BatchNorm = warp(jt.nn.BatchNorm, BatchNormACL) + # jt.nn.LayerNorm = warp(jt.nn.LayerNorm, LayerNormACL) + + jt.nn.FlashAttention = warp(jt.nn.FlashAttention, FlashAttentionACL) + jt.isnan = warp(jt.isnan, isnan_acl) + jt.isinf = warp(jt.isinf, isinf_acl) + jt.Var.isnan = jt.isnan + jt.Var.isinf = jt.isinf + + jt.nn.rotary_emb = rope_acl \ No newline at end of file diff --git a/python/jittor/extern/acl/acl_error_code.cc b/python/jittor/extern/acl/acl_error_code.cc index 04df8c20..5fd45dbf 100644 --- a/python/jittor/extern/acl/acl_error_code.cc +++ b/python/jittor/extern/acl/acl_error_code.cc @@ -1,6 +1,6 @@ // *************************************************************** -// Copyright (c) 2023 Jittor. All Rights Reserved. -// Maintainers: Dun Liang . +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . // This file is subject to the terms and conditions defined in // file 'LICENSE.txt', which is part of this source code package. // *************************************************************** @@ -11,23 +11,27 @@ using std::unordered_map; typedef int aclError; -static inline unordered_map gen_map(string s) { - unordered_map smap; - for (int i=0; i gen_map(string s) +{ + unordered_map smap; + for (int i = 0; i < s.size(); i++) + { + if (s[i] == ';') + { + int j = s.rfind(" ", i); + int code = std::stoi(s.substr(j + 1, i - j - 1)); + int k = s.rfind(" ", j - 1); + int l = s.rfind(" ACL_", k - 1); + smap[code] = s.substr(l + 1, k - l - 1); } } return smap; } -string acl_error_to_string(aclError error) { +string acl_error_to_string(aclError error) +{ -static unordered_map acl_error_map = gen_map(R"( + static unordered_map acl_error_map = gen_map(R"( // from acl_base.h static const int ACL_ERROR_INVALID_PARAM = 100000; static const int ACL_ERROR_UNINITIALIZE = 100001; diff --git a/python/jittor/extern/acl/acl_jittor.cc b/python/jittor/extern/acl/acl_jittor.cc index be0c17f4..6e184020 100644 --- a/python/jittor/extern/acl/acl_jittor.cc +++ b/python/jittor/extern/acl/acl_jittor.cc @@ -1,6 +1,6 @@ // *************************************************************** -// Copyright (c) 2023 Jittor. All Rights Reserved. -// Maintainers: Dun Liang . +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . // This file is subject to the terms and conditions defined in // file 'LICENSE.txt', which is part of this source code package. 
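gen_map() in the acl_error_code.cc hunk above scans a pasted block of "static const int ACL_... = code;" lines and builds the code-to-name table used by acl_error_to_string. The same mapping, sketched in Python with a regex purely for illustration (the real parser is the C++ string scan shown in the diff):

import re

header_snippet = """
static const int ACL_ERROR_INVALID_PARAM = 100000;
static const int ACL_ERROR_UNINITIALIZE = 100001;
"""
error_map = {int(code): name
             for name, code in re.findall(r"(ACL_\w+)\s*=\s*(\d+);", header_snippet)}
print(error_map[100000])   # ACL_ERROR_INVALID_PARAM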
// *************************************************************** @@ -10,267 +10,311 @@ #include "utils/str_utils.h" #include #include +#include "aclnn/aclnn.h" -namespace jittor { +namespace jittor +{ -uint64_t acl_jittor_tid; -int acl_jittor_thread_running=0; -aclrtContext acl_jittor_context; -aclrtStream aclstream; + uint64_t acl_jittor_tid; + int acl_jittor_thread_running = 0; + aclrtStream aclstream; + void *workspaceAddr = nullptr; + uint64_t nowWorkSpaceSize = 0; -#define CHECK_ACL(x) ASSERTop(x,==,0) +#define CHECK_ACL(x) ASSERTop(x, ==, 0) -static void* acl_jittor_process_callback(void*) { - acl_jittor_thread_running = 1; - int deviceId = 0; - CHECK_ACL(aclrtSetCurrentContext(acl_jittor_context)); - - while (acl_jittor_thread_running) { - // LOGir << "acl_jittor_process_callback"; - auto ret = aclrtProcessReport(1000); - if (ret) { - if (acl_jittor_thread_running && ret != ACL_ERROR_RT_REPORT_TIMEOUT && ret != ACL_ERROR_RT_THREAD_SUBSCRIBE) - LOGir << "aclrtProcessReport:" << ret << acl_error_to_string(ret); - break; + void mallocWorkSpace(uint64_t size) + { + uint64_t alloc_size = size + 32; + alloc_size = ((alloc_size - 1) / 32 + 1) * 32; + if (alloc_size > nowWorkSpaceSize) + { + aclrtFree(workspaceAddr); + nowWorkSpaceSize = alloc_size; + auto ret = aclrtMalloc(&workspaceAddr, nowWorkSpaceSize, ACL_MEM_MALLOC_HUGE_FIRST); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return); } } - acl_jittor_thread_running = 0; - return (void*)0; -} + static void *acl_jittor_process_callback(void *) + { + acl_jittor_thread_running = 1; -// void aaa(void*) { -// LOGir << "haha"; -// } - -struct acl_jittor_initer { - -acl_jittor_initer() { - CHECK_ACL(aclInit(nullptr)); - uint device_count = 0; - // 获取可用的Device数量 - CHECK_ACL(aclrtGetDeviceCount(&device_count)); - LOGi << "Found ACL device number:" << device_count; - CHECK_ACL(aclrtSetDevice(0)); - CHECK_ACL(aclrtCreateContext(&acl_jittor_context, 0)); - CHECK_ACL(aclrtSetCurrentContext(acl_jittor_context)); - - pthread_create(&acl_jittor_tid, nullptr, acl_jittor_process_callback, 0); - - // subscribe for default stream - CHECK_ACL(aclrtSubscribeReport(acl_jittor_tid,0)); - - // simple callback test - CHECK_ACL(aclrtCreateStream(&aclstream)); - // CHECK_ACL(aclrtSubscribeReport(acl_jittor_tid,aclstream)); - // CHECK_ACL(aclrtLaunchCallback((aclrtCallback)&aaa, 0, ACL_CALLBACK_NO_BLOCK, aclstream)); - // CHECK_ACL(aclrtLaunchCallback((aclrtCallback)&aaa, 0, ACL_CALLBACK_NO_BLOCK, 0)); -} - -~acl_jittor_initer() { - acl_jittor_thread_running = 0; - CHECK_ACL(aclrtUnSubscribeReport(acl_jittor_tid,0)); - CHECK_ACL(aclrtDestroyContext(acl_jittor_context)); - CHECK_ACL(aclFinalize()); -} - -} _acl_jittor_initer; - -string process_acl(const string& src, const string& name, const map& kargs) { - if (endswith(name, "_jittor.cc")) - return src; - // static vector dont_compile = {"fp16_emu.cc"}; - // for (auto& s : dont_compile) - // if (endswith(name, s)) - // return " "; - static unordered_set cuda_headers = { - "cuda_runtime", "cudnn", "driver_types", - "cuda_fp16", "cuda_runtime_api", "fp16_emu", - "cudnn_rnn_descriptor", "cublas_v2", "cublas_wrapper", - "curand", "curand_wrapper", "cufft", "cufftXt", - "CudaUtils", "cutt", "cudnn_wrapper", "cuda_bf16" - }; - static unordered_set fake_class = { - "cudnnHandle_t", "cudnnConvolutionBwdFilterAlgo_t", - "cudnnConvolutionBwdDataAlgo_t", "cudnnConvolutionFwdAlgo_t", - "cufftHandle" - }; - try { - auto tokens = token_split(src); - int edit = 0; - for (int i=0; i=5 
&& token[4] >= 'A' && token[4] <= 'Z') { - if (token == "cudaGetDeviceCount") { - token_replace(tokens, i, "($1);", "((uint*)$1);"); - } else if (token == "cudaLaunchHostFunc") { - // ACL_CALLBACK_BLOCK for 310 - token_replace(tokens, i, "LaunchHostFunc($1,$2,$3)", - "LaunchCallback($2,$3,ACL_CALLBACK_NO_BLOCK,$1)"); - } else if (token == "cudaMemcpy") - token_replace(tokens, i, "cudaMemcpy($1,$2,$3,", - "aclrtMemcpy($1,$3,$2,$3,"); - else if (token == "cudaMemcpyAsync") - token_replace(tokens, i, "cudaMemcpyAsync($1,$2,$3,", - "aclrtMemcpyAsync($1,$3,$2,$3,"); - else if (token == "cudaMemcpyDeviceToHost") token = "ACL_MEMCPY_DEVICE_TO_HOST"; - else if (token == "cudaMemcpyDefault") token = "ACL_MEMCPY_HOST_TO_DEVICE"; - else if (token == "cudaMemcpyHostToDevice") token = "ACL_MEMCPY_HOST_TO_DEVICE"; - else if (token == "cudaMemcpyDeviceToDevice") token = "ACL_MEMCPY_DEVICE_TO_DEVICE"; - else if (token == "cudaMallocManaged" || token == "cudaMalloc") { - // unified address not supported - token = "aclrtMalloc"; - token_replace(tokens, i, "($1,$2)", - "($1,$2,ACL_MEM_MALLOC_HUGE_FIRST)"); - } else if (token == "cudaMemGetInfo") - token_replace(tokens, i, "cudaMemGetInfo($1,$2)", - "aclrtGetMemInfo(ACL_DDR_MEM,$1,$2)"); - else if (token == "cudaGetLastError") - token_replace(tokens, i, "cudaGetLastError()", "0"); - else if (token == "cudaStreamCreateWithFlags") - token_replace(tokens, i-1, - "(cudaStreamCreateWithFlags($1,$2));", - "(aclrtCreateStream($1)); checkAclErrors(aclrtSubscribeReport(acl_jittor_tid,*$1));"); - else if (token == "cudaEventCreate") - token_replace(tokens, i, - "cudaEventCreate($1,$2)", - "aclrtCreateEvent($1)"); - else if (token == "cudaDeviceSynchronize") - token = "aclrtSynchronizeDevice"; - else if (token == "cudaStreamDestroy") - token_replace(tokens, i, "cudaStreamDestroy($1)", - "(aclrtUnSubscribeReport(acl_jittor_tid,$1), aclrtDestroyStream($1))"); - else if (token == "cudaEventDestroy") - token = "aclrtDestroyEvent"; - else if (token == "cudaEventRecord") - token = "aclrtRecordEvent"; - else if (token == "cudaStreamWaitEvent") - token_replace(tokens, i, - "cudaStreamWaitEvent($1,$2,$3)", - "aclrtStreamWaitEvent($1,$2)"); - - if (token.size() && token[0] == 'c') - token = "aclrt" + token.substr(4); - if (endswith(token, "_t")) - token = token.substr(0, token.size()-2); - edit ++; - } - } else - if (token == "_cudaGetErrorEnum") { - token_replace(tokens, i, "_cudaGetErrorEnum($1)", "(acl_error_to_string($1))"); - edit ++; - } else - if (token == "checkCudaErrors") - token = "checkAclErrors"; - else if (token == "JPU") { - edit ++; - string new_code; - if (tokens[i+2] == "op_compiler") - token_replace(tokens, i, - "JPU(op_compiler($1,$2,$3))", - "acl_jittor_op_compiler($1,$2,$3)"); - else if (tokens[i+2] == "header") - new_code = "#include \"acl_jittor.h\""; - if (new_code.size()) - token_replace(tokens, i, "JPU($1)", new_code); - } else if (token == "use_cuda_managed_allocator" && tokens[i+1][0]==',') { - tokens[i+2] = "0"; // disable unified address - } - } - if (!edit) return src; - string new_src = join(tokens, ""); - // if (name == "executor.cc") { - // new_src = string("#include \n#include \n#include \n")+ - // "namespace jittor { void acl_op_exec(Op*); }\n" + - // replace(new_src, "op->do_run_after_prepare(jkl);", - // R"({ - // acl_op_exec(op); - // })"); - // } - if (name == "profiler.cc") { - new_src = token_replace_all(new_src, ".cc", ".tikcc"); - } - // LOGir << name << (name == "pass_manager.cc"); - if (name == "pass_manager.cc") { - LOGir << 
"replace" << name; - new_src = token_replace_all(new_src, "run_pass();", "WTF"); - } - // ???????? - return new_src; - } catch (const std::exception& e) { - LOGe << "process acl error:" << e.what(); - LOGe << "name:" << name; - throw; - } -} - -void acl_jittor_op_compiler(string& filename, string& src, bool is_acl, string& extra_flags) { - if (!is_acl) return; - // extra_flags += " --tik-soc-version=Ascend910 "; - // filename = replace(filename, ".cc", ".tikcc"); - // LOGir << filename; - string new_src = process_acl(src, "", {}); - new_src = replace(new_src, R"(#include "misc/cuda_atomic.h")", ""); - new_src = replace(new_src, R"(#include "misc/cuda_limits.h")", ""); - new_src = replace(new_src, "__global__", "__ai_device_entry__"); - new_src = token_replace_all(new_src, "__launch_bounds__($1)", ""); - new_src = token_replace_all(new_src, "int thread_num = $1;", "int thread_num = 1;"); - new_src = token_replace_all(new_src, "tn0=std::max(tn0, $1);", ""); - new_src = token_replace_all(new_src, "<<<$1>>>", "<<<1,0>>>"); - new_src = token_replace_all(new_src, "int thread_id = $1;", "int thread_id = 1;"); - // for inc error - new_src = token_replace_all(new_src, "for ($1+=$2)", "for ($1++)"); - // bit op error - new_src = token_replace_all(new_src, "int tnum$1;", ""); - new_src = token_replace_all(new_src, "int p1$1;", ""); - new_src = token_replace_all(new_src, "int p2$1;", ""); - new_src = token_replace_all(new_src, "int tn$1=$2;", "int tn$1=0;"); - new_src = token_replace_all(new_src, "int tid$1=$2;", "int tid$1=0;"); - src = new_src; - - new_src = token_replace_all(new_src, "atomicAdd(&$1,$2);", "$1=$1+$2;"); - // new_src = token_replace_all(new_src, "bool", "int8"); - new_src = token_replace_all(new_src, "::numeric_min()", "-1e30"); - new_src = token_replace_all(new_src, "::numeric_max()", "1e30"); - // TODO: support max - unordered_map opmap = { - // {"::max","tikcc::scalar_max"}, - {"::sqrtf", "tikcc::scalar_sqrt"} - }; - auto ss = split(new_src, ";"); - for (auto &s : ss) { - if (s.find("?") != string::npos) { - s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;"); - } - if (s.find("::max") != string::npos) { - if (s.find("auto") == string::npos) { - s = token_replace_all(s+";", " $1=$4::max($2,$3);", " $1=$2;if ($2 < $3) $1=$3;"); - } else { - s = token_replace_all(s+";", "auto $1=$4::max($2,$3);", "auto $1=$2;if ($2 < $3) $1=$3;"); + while (acl_jittor_thread_running) + { + // LOGir << "acl_jittor_process_callback"; + auto ret = aclrtProcessReport(1000); + if (ret) + { + if (acl_jittor_thread_running && ret != ACL_ERROR_RT_REPORT_TIMEOUT && ret != ACL_ERROR_RT_THREAD_SUBSCRIBE) + LOGir << "aclrtProcessReport:" << ret << acl_error_to_string(ret); + break; } } - for (auto& kv : opmap) { - if (s.find(kv.first) != string::npos) { - if (s.find("auto") == string::npos) { - // $1 = op($2) --> op($1, $2) - s = token_replace_all(s+";", " $1= "+kv.first+"($2);", kv.second+"($1, $2);"); - } else { - // auto $1 = op($2) --> float32 $1; op($1, $2); - s = token_replace_all(s+";", "auto $1= "+kv.first+"($2);", "float32 $1; " + kv.second+"($1, $2);"); + acl_jittor_thread_running = 0; + return (void *)0; + } + + struct acl_jittor_initer + { + int32_t deviceId; + acl_jittor_initer() + { + CHECK_ACL(aclInit(nullptr)); + uint device_count = 0; + deviceId = 0; + // 获取可用的Device数量 + CHECK_ACL(aclrtGetDeviceCount(&device_count)); + LOGi << "Found ACL device number:" << device_count; + CHECK_ACL(aclrtSetDevice(deviceId)); + CHECK_ACL(aclrtCreateStream(&aclstream)); + // 
pthread_create(&acl_jittor_tid, nullptr, acl_jittor_process_callback, 0); + } + + ~acl_jittor_initer() + { + acl_jittor_thread_running = 0; + // CHECK_ACL(aclrtUnSubscribeReport(acl_jittor_tid, 0)); + aclrtDestroyStream(aclstream); + aclrtResetDevice(deviceId); + CHECK_ACL(aclFinalize()); + if (nowWorkSpaceSize > 0) + { + aclrtFree(workspaceAddr); + } + } + + } _acl_jittor_initer; + + string process_acl(const string &src, const string &name, const map &kargs) + { + if (endswith(name, "_jittor.cc")) + return src; + // static vector dont_compile = {"fp16_emu.cc"}; + // for (auto& s : dont_compile) + // if (endswith(name, s)) + // return " "; + static unordered_set cuda_headers = { + "cuda_runtime", "cudnn", "driver_types", + "cuda_fp16", "cuda_runtime_api", "fp16_emu", + "cudnn_rnn_descriptor", "cublas_v2", "cublas_wrapper", + "curand", "curand_wrapper", "cufft", "cufftXt", + "CudaUtils", "cutt", "cudnn_wrapper", "cuda_bf16"}; + static unordered_set fake_class = { + "cudnnHandle_t", "cudnnConvolutionBwdFilterAlgo_t", + "cudnnConvolutionBwdDataAlgo_t", "cudnnConvolutionFwdAlgo_t", + "cufftHandle"}; + try + { + auto tokens = token_split(src); + int edit = 0; + for (int i = 0; i < tokens.size(); i++) + { + auto &token = tokens[i]; + if (cuda_headers.count(token)) + token = "acl_jittor", edit++; + else if (fake_class.count(token)) + token = "int", edit++; + else if (token == "CUDA") + token = "ACL", edit++; + else if (startswith(token, "cuda")) + { + if (token.size() >= 5 && token[4] >= 'A' && token[4] <= 'Z') + { + if (token == "cudaGetDeviceCount") + { + token_replace(tokens, i, "($1);", "((uint*)$1);"); + } + else if (token == "cudaLaunchHostFunc") + { + // ACL_CALLBACK_BLOCK for 310 + token_replace(tokens, i, "LaunchHostFunc($1,$2,$3)", + "LaunchCallback($2,$3,ACL_CALLBACK_NO_BLOCK,$1)"); + } + else if (token == "cudaMemcpy") + token_replace(tokens, i, "cudaMemcpy($1,$2,$3,", + "aclrtMemcpy($1,$3,$2,$3,"); + else if (token == "cudaMemcpyAsync") + token_replace(tokens, i, "cudaMemcpyAsync($1,$2,$3,", + "aclrtMemcpyAsync($1,$3,$2,$3,"); + else if (token == "cudaMemcpyDeviceToHost") + token = "ACL_MEMCPY_DEVICE_TO_HOST"; + else if (token == "cudaMemcpyDefault") + token = "ACL_MEMCPY_HOST_TO_DEVICE"; + else if (token == "cudaMemcpyHostToDevice") + token = "ACL_MEMCPY_HOST_TO_DEVICE"; + else if (token == "cudaMemcpyDeviceToDevice") + token = "ACL_MEMCPY_DEVICE_TO_DEVICE"; + else if (token == "cudaMallocManaged" || token == "cudaMalloc") + { + // unified address not supported + token = "aclrtMalloc"; + token_replace(tokens, i, "($1,$2)", + "($1,$2,ACL_MEM_MALLOC_HUGE_FIRST)"); + } + else if (token == "cudaMemGetInfo") + token_replace(tokens, i, "cudaMemGetInfo($1,$2)", + "aclrtGetMemInfo(ACL_DDR_MEM,$1,$2)"); + else if (token == "cudaGetLastError") + token_replace(tokens, i, "cudaGetLastError()", "0"); + else if (token == "cudaStreamCreateWithFlags") + token_replace(tokens, i - 1, + "(cudaStreamCreateWithFlags($1,$2));", + "(aclrtCreateStream($1)); checkAclErrors(aclrtSubscribeReport(acl_jittor_tid,*$1));"); + else if (token == "cudaEventCreate") + token_replace(tokens, i, + "cudaEventCreate($1,$2)", + "aclrtCreateEvent($1)"); + else if (token == "cudaDeviceSynchronize") + token = "aclrtSynchronizeDevice"; + else if (token == "cudaStreamDestroy") + token_replace(tokens, i, "cudaStreamDestroy($1)", + "(aclrtUnSubscribeReport(acl_jittor_tid,$1), aclrtDestroyStream($1))"); + else if (token == "cudaEventDestroy") + token = "aclrtDestroyEvent"; + else if (token == "cudaEventRecord") + token = 
"aclrtRecordEvent"; + else if (token == "cudaStreamWaitEvent") + token_replace(tokens, i, + "cudaStreamWaitEvent($1,$2,$3)", + "aclrtStreamWaitEvent($1,$2)"); + + if (token.size() && token[0] == 'c') + token = "aclrt" + token.substr(4); + if (endswith(token, "_t")) + token = token.substr(0, token.size() - 2); + edit++; + } + } + else if (token == "_cudaGetErrorEnum") + { + token_replace(tokens, i, "_cudaGetErrorEnum($1)", "(acl_error_to_string($1))"); + edit++; + } + else if (token == "checkCudaErrors") + token = "checkAclErrors"; + else if (token == "JPU") + { + edit++; + string new_code; + if (tokens[i + 2] == "op_compiler") + token_replace(tokens, i, + "JPU(op_compiler($1,$2,$3))", + "acl_jittor_op_compiler($1,$2,$3)"); + else if (tokens[i + 2] == "header") + new_code = "#include \"acl_jittor.h\""; + if (new_code.size()) + token_replace(tokens, i, "JPU($1)", new_code); + } + else if (token == "use_cuda_managed_allocator" && tokens[i + 1][0] == ',') + { + tokens[i + 2] = "0"; // disable unified address } } + if (!edit) + return src; + string new_src = join(tokens, ""); + // if (name == "executor.cc") { + // new_src = string("#include \n#include \n#include \n")+ + // "namespace jittor { void acl_op_exec(Op*); }\n" + + // replace(new_src, "op->do_run_after_prepare(jkl);", + // R"({ + // acl_op_exec(op); + // })"); + // } + if (name == "profiler.cc") + { + new_src = token_replace_all(new_src, ".cc", ".tikcc"); + } + // LOGir << name << (name == "pass_manager.cc"); + if (name == "pass_manager.cc") + { + LOGir << "replace" << name; + new_src = token_replace_all(new_src, "run_pass();", "WTF"); + } + // ???????? + return new_src; + } + catch (const std::exception &e) + { + LOGe << "process acl error:" << e.what(); + LOGe << "name:" << name; + throw; } - // s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;"); - // s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;"); - // if (s.find("::max") != string::npos) { - // s = token_replace_all(s+";", " $1= ::max($2);", "tikcc::scalar_max($1, $2);"); - // } } - new_src = join(ss, ";"); - src = new_src; -} -} + void acl_jittor_op_compiler(string &filename, string &src, bool is_acl, string &extra_flags) + { + if (!is_acl) + return; + string new_src = process_acl(src, "", {}); + new_src = replace(new_src, R"(#include "misc/cuda_atomic.h")", ""); + new_src = replace(new_src, R"(#include "misc/cuda_limits.h")", ""); + new_src = replace(new_src, "__global__", "__ai_device_entry__"); + new_src = token_replace_all(new_src, "__launch_bounds__($1)", ""); + new_src = token_replace_all(new_src, "int thread_num = $1;", "int thread_num = 1;"); + new_src = token_replace_all(new_src, "tn0=std::max(tn0, $1);", ""); + new_src = token_replace_all(new_src, "<<<$1>>>", "<<<1,0>>>"); + new_src = token_replace_all(new_src, "int thread_id = $1;", "int thread_id = 1;"); + // for inc error + new_src = token_replace_all(new_src, "for ($1+=$2)", "for ($1++)"); + // bit op error + new_src = token_replace_all(new_src, "int tnum$1;", ""); + new_src = token_replace_all(new_src, "int p1$1;", ""); + new_src = token_replace_all(new_src, "int p2$1;", ""); + new_src = token_replace_all(new_src, "int tn$1=$2;", "int tn$1=0;"); + new_src = token_replace_all(new_src, "int tid$1=$2;", "int tid$1=0;"); + src = new_src; + + new_src = token_replace_all(new_src, "atomicAdd(&$1,$2);", "$1=$1+$2;"); + // new_src = token_replace_all(new_src, "bool", "int8"); + new_src = token_replace_all(new_src, "::numeric_min()", "-1e30"); + new_src = 
token_replace_all(new_src, "::numeric_max()", "1e30"); + // TODO: support max + unordered_map opmap = { + // {"::max","tikcc::scalar_max"}, + {"::sqrtf", "tikcc::scalar_sqrt"}}; + auto ss = split(new_src, ";"); + for (auto &s : ss) + { + if (s.find("?") != string::npos) + { + s = token_replace_all(s + ";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;"); + } + if (s.find("::max") != string::npos) + { + if (s.find("auto") == string::npos) + { + s = token_replace_all(s + ";", " $1=$4::max($2,$3);", " $1=$2;if ($2 < $3) $1=$3;"); + } + else + { + s = token_replace_all(s + ";", "auto $1=$4::max($2,$3);", "auto $1=$2;if ($2 < $3) $1=$3;"); + } + } + for (auto &kv : opmap) + { + if (s.find(kv.first) != string::npos) + { + if (s.find("auto") == string::npos) + { + // $1 = op($2) --> op($1, $2) + s = token_replace_all(s + ";", " $1= " + kv.first + "($2);", kv.second + "($1, $2);"); + } + else + { + // auto $1 = op($2) --> float32 $1; op($1, $2); + s = token_replace_all(s + ";", "auto $1= " + kv.first + "($2);", "float32 $1; " + kv.second + "($1, $2);"); + } + } + } + // s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;"); + // s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;"); + // if (s.find("::max") != string::npos) { + // s = token_replace_all(s+";", " $1= ::max($2);", "tikcc::scalar_max($1, $2);"); + // } + } + new_src = join(ss, ";"); + src = new_src; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/acl_jittor.h b/python/jittor/extern/acl/acl_jittor.h index 0ef90b40..ee9960cb 100644 --- a/python/jittor/extern/acl/acl_jittor.h +++ b/python/jittor/extern/acl/acl_jittor.h @@ -1,6 +1,6 @@ // *************************************************************** -// Copyright (c) 2023 Jittor. All Rights Reserved. -// Maintainers: Dun Liang . +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . // This file is subject to the terms and conditions defined in // file 'LICENSE.txt', which is part of this source code package. 
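// acl_jittor.h below carries the plumbing for the aclnn backend: AclOpFunctions
// pairs each operator's *GetWorkspaceSize query with its execute entry point,
// aclOpFuncMap maps Jittor op names onto those pairs, and the AclOpAttr
// subclasses hold per-op attributes (strides, axes, dropout probability, ...).
// A rough sketch of how such a pair is presumably consumed -- an illustration
// under assumed aclnn signatures, not code taken from this patch:
//
//   uint64_t ws = 0;
//   aclOpExecutor *executor = nullptr;
//   auto &f = aclOpFuncMap.at("Abs");
//   f.getWorkspaceSizeFuncUnaryNonzero(in, out, &ws, &executor); // phase 1: query size
//   mallocWorkSpace(ws);                                         // grow shared workspace
//   f.executeFunc(workspaceAddr, ws, executor, aclstream);       // phase 2: launch kernel
//   aclrtSynchronizeStream(aclstream);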
// *************************************************************** @@ -10,11 +10,690 @@ std::string acl_error_to_string(aclError error); -namespace jittor { +namespace jittor +{ -EXTERN_LIB uint64_t acl_jittor_tid; -EXTERN_LIB aclrtStream aclstream; + EXTERN_LIB uint64_t acl_jittor_tid; + EXTERN_LIB aclrtStream aclstream; + EXTERN_LIB void *workspaceAddr; -void acl_jittor_op_compiler(string& filename, string& src, bool is_acl, string& extra_flags); + void mallocWorkSpace(uint64_t size); -} + void acl_jittor_op_compiler(string &filename, string &src, bool is_acl, string &extra_flags); + + struct AclOpFunctions + { + // for Unary and Nonzero + std::function getWorkspaceSizeFuncUnaryNonzero; + // for Cast + std::function getWorkspaceSizeFuncCast; + // for Bianry + std::function getWorkspaceSizeFuncBinary; + // for Add and Sub + std::function getWorkspaceSizeFuncAdd; + // for Expand, permute, flip + std::function getWorkspaceSizeFuncExpand; + // for bmm and matmul + std::function getWorkspaceSizeFuncMatmul; + // for conv + std::function getWorkspaceSizeFuncConv; + // for reducesum, mean + std::function getWorkspaceSizeFuncReduceSum; + // for amax and amin + std::function getWorkspaceSizeFuncAmax; + // for conv backward + std::function getWorkspaceSizeFuncConvBackward; + // for proddim + std::function getWorkspaceSizeFuncProdDim; + // for select, where + std::function getWorkspaceSizeFuncSelect; + // for random_uniform and random_normal + std::function getWorkspaceSizeFuncRandom; + // for maxpool + std::function getWorkspaceSizeFuncMaxPool; + // for maxpool backward + std::function getWorkspaceSizeFuncMaxPoolBackward; + // for avgpool + std::function getWorkspaceSizeFuncAvgPool; + // for avgpool backward + std::function getWorkspaceSizeFuncAvgPoolBackward; + // for concat + std::function getWorkspaceSizeFuncConcat; + // for gather + std::function getWorkspaceSizeFuncGather; + // for cumsum + std::function getWorkspaceSizeFuncCumsum; + // for scatter + std::function getWorkspaceSizeFuncScatter; + // for index + std::function getWorkspaceSizeFuncIndex; + // for stridesliceassignv2 + std::function getWorkspaceSizeFuncStridedSliceAssignV2; + // for slicev2 + std::function getWorkspaceSizeFuncSliceV2; + // for indexputimpl + std::function getWorkspaceSizeFuncIndexPutImpl; + // for range + std::function getWorkspaceSizeFuncRange; + // for leaky_relu + std::function getWorkspaceSizeFuncLeakyRelu; + // for leaky_relu backward + std::function getWorkspaceSizeFuncLeakyReluBackward; + // for dropout + std::function getWorkspaceSizeFuncDropout; + // for dropout backward + std::function getWorkspaceSizeFuncDropoutBackward; + // for split with size + std::function getWorkspaceSizeFuncSplitWithSize; + + // for silu + // std::function getWorkspaceSizeFuncSilu; + + // for silu backward + // std::function getWorkspaceSizeFuncSiluBackward; + + // for sigmoid + // std::function getWorkspaceSizeFuncSigmoid; + + // for sigmoid backward + // std::function getWorkspaceSizeFuncSigmoidBackward; + + // for embedding + // std::function getWorkspaceSizeFuncEmbedding; + + // for embedding backward + std::function getWorkspaceSizeFuncEmbeddingBackward; + + // for InplaceMaskedScatter MaskedSelect + // std::function getWorkspaceSizeFuncInplaceMaskedScatter; + std::function executeFunc; + + // for flashattention + std::function + getWorkspaceSizeFuncFalshAttention; + + // for flashattention backward + std::function + getWorkspaceSizeFuncFalshAttentionBackward; + + // for batchnorm + std::function getWorkspaceSizeFuncBatchNorm; 
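// Each AclOpFunctions instance populates exactly one of these
// getWorkspaceSizeFunc* members (the matching constructor overload further
// down fills it), while executeFunc is common to all of them: every aclnn
// execute call takes the same (workspace, workspaceSize, executor, stream)
// arguments, so a single member covers every op family.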
+ + // for batchnorm backward + std::function getWorkspaceSizeFuncBatchNormBackward; + + // for layernorm + std::function getWorkspaceSizeFuncLayerNorm; + + // for ROPE + std::function + getWorkspaceSizeFuncRotaryPosEmb; + + // 添加一个默认构造函数 + AclOpFunctions() = default; + + // for Unary and Nonzero + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncUnaryNonzero(gwsf), executeFunc(execf) {} + + // for Cast + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncCast(gwsf), executeFunc(execf) {} + + // for Binary + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncBinary(gwsf), executeFunc(execf) {} + // for Add and Sub + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncAdd(gwsf), executeFunc(execf) {} + + // for Expand, flip + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncExpand(gwsf), executeFunc(execf) {} + + // for Matmul + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncMatmul(gwsf), executeFunc(execf) {} + + // for conv + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncConv(gwsf), executeFunc(execf) {} + + // for reducesum, mean + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncReduceSum(gwsf), executeFunc(execf) {} + + // for amax amin + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncAmax(gwsf), executeFunc(execf) {} + + // for conv backward + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncConvBackward(gwsf), executeFunc(execf) {} + + // for proddim + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncProdDim(gwsf), executeFunc(execf) {} + + // for select, where + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncSelect(gwsf), executeFunc(execf) {} + + // for random_normal + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncRandom(gwsf), executeFunc(execf) {} + + // for maxpool + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncMaxPool(gwsf), executeFunc(execf) {} + + // for maxpool backward + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncMaxPoolBackward(gwsf), executeFunc(execf) {} + + // for avgpool + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncAvgPool(gwsf), executeFunc(execf) {} + + // for avgpool backward + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncAvgPoolBackward(gwsf), executeFunc(execf) {} + + // for concat + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncConcat(gwsf), executeFunc(execf) {} + + // for gather + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncGather(gwsf), executeFunc(execf) {} + + // for cumsum + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncCumsum(gwsf), executeFunc(execf) {} + + // for scatter + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncScatter(gwsf), executeFunc(execf) {} + + // for index + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncIndex(gwsf), executeFunc(execf) {} + + // for stridesliceassignv2 + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncStridedSliceAssignV2(gwsf), 
executeFunc(execf) {} + + // for slicev2 + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncSliceV2(gwsf), executeFunc(execf) {} + + // for indexputimpl + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncIndexPutImpl(gwsf), executeFunc(execf) {} + + // for range + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncRange(gwsf), executeFunc(execf) {} + + // for leaky_relu + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncLeakyRelu(gwsf), executeFunc(execf) {} + + // for leaky_relu backward + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncLeakyReluBackward(gwsf), executeFunc(execf) {} + + // for dropout + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncDropout(gwsf), executeFunc(execf) {} + + // for dropout backward + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncDropoutBackward(gwsf), executeFunc(execf) {} + + // for embedding backward + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncEmbeddingBackward(gwsf), executeFunc(execf) {} + + // for split with size + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncSplitWithSize(gwsf), executeFunc(execf) {} + + // for flash attention + AclOpFunctions(std::function + gwsf, + std::function execf) + : getWorkspaceSizeFuncFalshAttention(gwsf), executeFunc(execf) {} + + // for flash attention backward + AclOpFunctions(std::function + gwsf, + std::function execf) + : getWorkspaceSizeFuncFalshAttentionBackward(gwsf), executeFunc(execf) {} + + // for batchnorm + AclOpFunctions(std::function + gwsf, + std::function execf) + : getWorkspaceSizeFuncBatchNorm(gwsf), executeFunc(execf) {} + + // for batchnorm backward + AclOpFunctions(std::function + gwsf, + std::function execf) + : getWorkspaceSizeFuncBatchNormBackward(gwsf), executeFunc(execf) {} + + // for layernorm + AclOpFunctions(std::function + gwsf, + std::function execf) + : getWorkspaceSizeFuncLayerNorm(gwsf), executeFunc(execf) {} + + // for ROPE + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncRotaryPosEmb(gwsf), executeFunc(execf) {} + }; + + static std::unordered_map aclOpFuncMap = { + {"Abs", AclOpFunctions(aclnnAbsGetWorkspaceSize, aclnnAbs)}, + {"Exp", AclOpFunctions(aclnnExpGetWorkspaceSize, aclnnExp)}, + {"Log", AclOpFunctions(aclnnLogGetWorkspaceSize, aclnnLog)}, + {"Sqrt", AclOpFunctions(aclnnSqrtGetWorkspaceSize, aclnnSqrt)}, + {"Ceil", AclOpFunctions(aclnnCeilGetWorkspaceSize, aclnnCeil)}, + {"Floor", AclOpFunctions(aclnnFloorGetWorkspaceSize, aclnnFloor)}, + {"Round", AclOpFunctions(aclnnRoundGetWorkspaceSize, aclnnRound)}, + {"Sin", AclOpFunctions(aclnnSinGetWorkspaceSize, aclnnSin)}, + {"Cos", AclOpFunctions(aclnnCosGetWorkspaceSize, aclnnCos)}, + {"Tan", AclOpFunctions(aclnnTanGetWorkspaceSize, aclnnTan)}, + {"Asin", AclOpFunctions(aclnnAsinGetWorkspaceSize, aclnnAsin)}, + {"Acos", AclOpFunctions(aclnnAcosGetWorkspaceSize, aclnnAcos)}, + {"Atan", AclOpFunctions(aclnnAtanGetWorkspaceSize, aclnnAtan)}, + {"Sinh", AclOpFunctions(aclnnSinhGetWorkspaceSize, aclnnSinh)}, + {"Cosh", AclOpFunctions(aclnnCoshGetWorkspaceSize, aclnnCosh)}, + {"Tanh", AclOpFunctions(aclnnTanhGetWorkspaceSize, aclnnTanh)}, + {"Asinh", AclOpFunctions(aclnnAsinhGetWorkspaceSize, aclnnAsinh)}, + {"Acosh", AclOpFunctions(aclnnAcoshGetWorkspaceSize, aclnnAcosh)}, + {"Atanh", 
AclOpFunctions(aclnnAtanhGetWorkspaceSize, aclnnAtanh)}, + {"Sigmoid", AclOpFunctions(aclnnSigmoidGetWorkspaceSize, aclnnSigmoid)}, + {"Erf", AclOpFunctions(aclnnErfGetWorkspaceSize, aclnnErf)}, + {"Erfinv", AclOpFunctions(aclnnErfinvGetWorkspaceSize, aclnnErfinv)}, + {"LogicalNot", AclOpFunctions(aclnnLogicalNotGetWorkspaceSize, aclnnLogicalNot)}, + {"BitwiseNot", AclOpFunctions(aclnnBitwiseNotGetWorkspaceSize, aclnnBitwiseNot)}, + {"Neg", AclOpFunctions(aclnnNegGetWorkspaceSize, aclnnNeg)}, + {"Cast", AclOpFunctions(aclnnCastGetWorkspaceSize, aclnnCast)}, + {"Maximum", AclOpFunctions(aclnnMaximumGetWorkspaceSize, aclnnMaximum)}, + {"Minimum", AclOpFunctions(aclnnMinimumGetWorkspaceSize, aclnnMinimum)}, + {"Add", AclOpFunctions(aclnnAddGetWorkspaceSize, aclnnAdd)}, + {"Sub", AclOpFunctions(aclnnSubGetWorkspaceSize, aclnnSub)}, + {"Mul", AclOpFunctions(aclnnMulGetWorkspaceSize, aclnnMul)}, + {"RealDiv", AclOpFunctions(aclnnDivGetWorkspaceSize, aclnnDiv)}, + {"FloorDiv", AclOpFunctions(aclnnFloorDivideGetWorkspaceSize, aclnnFloorDivide)}, + {"LessEqual", AclOpFunctions(aclnnLeTensorGetWorkspaceSize, aclnnLeTensor)}, + {"Less", AclOpFunctions(aclnnLtTensorGetWorkspaceSize, aclnnLtTensor)}, + {"GreaterEqual", AclOpFunctions(aclnnGeTensorGetWorkspaceSize, aclnnGeTensor)}, + {"Greater", AclOpFunctions(aclnnGtTensorGetWorkspaceSize, aclnnGtTensor)}, + {"Equal", AclOpFunctions(aclnnEqTensorGetWorkspaceSize, aclnnEqTensor)}, + {"NotEqual", AclOpFunctions(aclnnNeTensorGetWorkspaceSize, aclnnNeTensor)}, + {"LogicalAnd", AclOpFunctions(aclnnLogicalAndGetWorkspaceSize, aclnnLogicalAnd)}, + {"LogicalOr", AclOpFunctions(aclnnLogicalOrGetWorkspaceSize, aclnnLogicalOr)}, + {"LogicalXor", AclOpFunctions(aclnnLogicalXorGetWorkspaceSize, aclnnLogicalXor)}, + {"BitwiseAnd", AclOpFunctions(aclnnBitwiseAndTensorGetWorkspaceSize, aclnnBitwiseAndTensor)}, + {"BitwiseOr", AclOpFunctions(aclnnBitwiseOrTensorGetWorkspaceSize, aclnnBitwiseOrTensor)}, + {"BitwiseXor", AclOpFunctions(aclnnBitwiseXorTensorGetWorkspaceSize, aclnnBitwiseXorTensor)}, + {"Pow", AclOpFunctions(aclnnPowTensorTensorGetWorkspaceSize, aclnnPowTensorTensor)}, + {"Expand", AclOpFunctions(aclnnExpandGetWorkspaceSize, aclnnExpand)}, + {"MatMul", AclOpFunctions(aclnnMatmulGetWorkspaceSize, aclnnMatmul)}, + {"BatchMatMul", AclOpFunctions(aclnnBatchMatMulGetWorkspaceSize, aclnnBatchMatMul)}, + {"ReduceMax", AclOpFunctions(aclnnAmaxGetWorkspaceSize, aclnnAmax)}, + {"ReduceMin", AclOpFunctions(aclnnAminGetWorkspaceSize, aclnnAmin)}, + {"ReduceSum", AclOpFunctions(aclnnReduceSumGetWorkspaceSize, aclnnReduceSum)}, + {"Triu", AclOpFunctions(aclnnTriuGetWorkspaceSize, aclnnTriu)}, + {"Conv2d", AclOpFunctions(aclnnConvolutionGetWorkspaceSize, aclnnConvolution)}, + {"Conv2dBackward", AclOpFunctions(aclnnConvolutionBackwardGetWorkspaceSize, aclnnConvolutionBackward)}, + {"ReduceMean", AclOpFunctions(aclnnMeanGetWorkspaceSize, aclnnMean)}, + // {"ReduceProd", AclOpFunctions(aclnnProdDimGetWorkspaceSize, aclnnProdDim)}, + {"Select", AclOpFunctions(aclnnSWhereGetWorkspaceSize, aclnnSWhere)}, + {"RandomUniform", AclOpFunctions(aclnnInplaceUniformGetWorkspaceSize, aclnnInplaceUniform)}, + {"RandomNormal", AclOpFunctions(aclnnInplaceNormalGetWorkspaceSize, aclnnInplaceNormal)}, + {"Transpose", AclOpFunctions(aclnnPermuteGetWorkspaceSize, aclnnPermute)}, + {"Maxpool", AclOpFunctions(aclnnMaxPool2dWithIndicesGetWorkspaceSize, aclnnMaxPool2dWithIndices)}, + {"MaxpoolBackward", AclOpFunctions(aclnnMaxPool2dWithIndicesBackwardGetWorkspaceSize, 
aclnnMaxPool2dWithIndicesBackward)}, + {"Avgpool", AclOpFunctions(aclnnAvgPool2dGetWorkspaceSize, aclnnAvgPool2d)}, + {"AvgpoolBackward", AclOpFunctions(aclnnAvgPool2dBackwardGetWorkspaceSize, aclnnAvgPool2dBackward)}, + {"Flip", AclOpFunctions(aclnnFlipGetWorkspaceSize, aclnnFlip)}, + {"Concat", AclOpFunctions(aclnnCatGetWorkspaceSize, aclnnCat)}, + {"Gather", AclOpFunctions(aclnnGatherGetWorkspaceSize, aclnnGather)}, + {"Cumsum", AclOpFunctions(aclnnCumsumGetWorkspaceSize, aclnnCumsum)}, + {"Index", AclOpFunctions(aclnnIndexGetWorkspaceSize, aclnnIndex)}, + {"Scatter", AclOpFunctions(aclnnScatterGetWorkspaceSize, aclnnScatter)}, + {"Nonzero", AclOpFunctions(aclnnNonzeroGetWorkspaceSize, aclnnNonzero)}, + {"Where", AclOpFunctions(aclnnSWhereGetWorkspaceSize, aclnnSWhere)}, + {"Floor", AclOpFunctions(aclnnFloorGetWorkspaceSize, aclnnFloor)}, + {"StridedSliceAssignV2", AclOpFunctions(aclnnStridedSliceAssignV2GetWorkspaceSize, aclnnStridedSliceAssignV2)}, + {"SliceV2", AclOpFunctions(aclnnSliceV2GetWorkspaceSize, aclnnSliceV2)}, + {"IndexPutImpl", AclOpFunctions(aclnnIndexPutImplGetWorkspaceSize, aclnnIndexPutImpl)}, + {"IndexPutImplAccumulate", AclOpFunctions(aclnnIndexPutImplGetWorkspaceSize, aclnnIndexPutImpl)}, + {"Range", AclOpFunctions(aclnnRangeGetWorkspaceSize, aclnnRange)}, + {"ReLU", AclOpFunctions(aclnnReluGetWorkspaceSize, aclnnRelu)}, + {"LeakyReLU", AclOpFunctions(aclnnLeakyReluGetWorkspaceSize, aclnnLeakyRelu)}, + {"LeakyReLUBackward", AclOpFunctions(aclnnLeakyReluBackwardGetWorkspaceSize, aclnnLeakyReluBackward)}, + {"Dropout", AclOpFunctions(aclnnDropoutGetWorkspaceSize, aclnnDropout)}, + {"DropoutBackward", AclOpFunctions(aclnnDropoutBackwardGetWorkspaceSize, aclnnDropoutBackward)}, + {"SiLU", AclOpFunctions(aclnnSiluGetWorkspaceSize, aclnnSilu)}, + {"SiLUBackward", AclOpFunctions(aclnnSiluBackwardGetWorkspaceSize, aclnnSiluBackward)}, + {"Sigmoid", AclOpFunctions(aclnnSigmoidGetWorkspaceSize, aclnnSigmoid)}, + {"SigmoidBackward", AclOpFunctions(aclnnSigmoidBackwardGetWorkspaceSize, aclnnSigmoidBackward)}, + {"Embedding", AclOpFunctions(aclnnEmbeddingGetWorkspaceSize, aclnnEmbedding)}, + {"EmbeddingBackward", AclOpFunctions(aclnnEmbeddingDenseBackwardGetWorkspaceSize, aclnnEmbeddingDenseBackward)}, + {"InplaceMaskedScatter", AclOpFunctions(aclnnInplaceMaskedScatterGetWorkspaceSize, aclnnInplaceMaskedScatter)}, + {"MaskedSelect", AclOpFunctions(aclnnMaskedSelectGetWorkspaceSize, aclnnMaskedSelect)}, + {"SplitWithSize", AclOpFunctions(aclnnSplitWithSizeGetWorkspaceSize, aclnnSplitWithSize)}, + {"Softmax", AclOpFunctions(aclnnSoftmaxGetWorkspaceSize, aclnnSoftmax)}, + {"SoftmaxBackward", AclOpFunctions(aclnnSoftmaxBackwardGetWorkspaceSize, aclnnSoftmaxBackward)}, + {"FlashAttention", AclOpFunctions(aclnnFlashAttentionScoreV2GetWorkspaceSize, aclnnFlashAttentionScoreV2)}, + {"FlashAttentionBackward", AclOpFunctions(aclnnFlashAttentionScoreGradV2GetWorkspaceSize, aclnnFlashAttentionScoreGradV2)}, + {"BatchNorm", AclOpFunctions(aclnnBatchNormGetWorkspaceSize, aclnnBatchNorm)}, + {"BatchNormBackward", AclOpFunctions(aclnnBatchNormBackwardGetWorkspaceSize, aclnnBatchNormBackward)}, + {"LayerNorm", AclOpFunctions(aclnnLayerNormGetWorkspaceSize, aclnnLayerNorm)}, + {"RotaryPosEmb", AclOpFunctions(aclnnApplyRotaryPosEmbGetWorkspaceSize, aclnnApplyRotaryPosEmb)}, + {"Stack", AclOpFunctions(aclnnStackGetWorkspaceSize, aclnnStack)}, + {"NanToNum", AclOpFunctions(aclnnNanToNumGetWorkspaceSize, aclnnNanToNum)}, + }; + + struct AclOpAttr + { + virtual ~AclOpAttr() {} + }; + + struct 
ConvAttr : AclOpAttr + { + vector convStrides; + vector convPads; + vector convOutPads; + vector convDilations; + bool convWithBias; + bool is_transposed; + int64_t group; + + // 析构函数 + ~ConvAttr() + { + convStrides.clear(); + convPads.clear(); + convOutPads.clear(); + convDilations.clear(); + } + }; + + struct ReduceAttr : AclOpAttr + { + vector axes; + // for proddim + int64_t prod_dim; + bool keepdims; + + ~ReduceAttr() + { + axes.clear(); + } + }; + + struct RandomAttr : AclOpAttr + { + int64_t seed, offset; + + ~RandomAttr() + { + } + }; + + struct TriuAttr : AclOpAttr + { + int64_t diagonal; + + ~TriuAttr() + { + } + }; + + struct PoolAttr : AclOpAttr + { + vector kernel_size; + vector poolStrides; + vector poolPads; + vector poolDilations; + bool poolCeil; + bool countIncludePad; + + // divisorOverride(const int64_t,计算输入): 表示取平均的除数。数据类型支持INT64。divisorOverride配置为默认值0时表示功能不使能。 + // https://www.hiascend.com/document/detail/zh/canncommercial/80RC2/apiref/appdevgapi/context/aclnnAvgPool2d.md + int64_t divisorOverride = 0; + + // cubeMathType(int8_t,计算输入): host侧的整型,判断Cube单元应该使用哪种计算逻辑进行运算,数据类型支持INT8。对于无特殊说明的数据类型,均保持原始输入数据类型计算。支持的枚举值如下: + // 0:KEEP_DTYPE,保持输入的数据类型进行计算。当输入是FLOAT,Atlas 训练系列产品和Atlas 推理系列产品(Ascend 310P处理器)暂不支持,取0时会报错。 + // 1:ALLOW_FP32_DOWN_PRECISION,允许将输入数据降精度计算。当输入是FLOAT,Atlas 训练系列产品和Atlas 推理系列产品(Ascend 310P处理器)允许转换为FLOAT16计算。 + // 2:USE_FP16,允许转换为数据类型FLOAT16进行计算。当输入数据类型是FLOAT,转换为FLOAT16计算。 + // 3:USE_HF32,允许转换为数据类型HFLOAT32计算。当输入是FLOAT,Atlas 训练系列产品、Atlas 推理系列产品(Ascend 310P处理器)和Atlas A2训练系列产品/Atlas 800I A2推理产品暂不支持,取3时会报错。 + // https://www.hiascend.com/document/detail/zh/canncommercial/80RC2/apiref/appdevgapi/context/aclnnAvgPool2d.md + int8_t cubeMathType = 0; + + // 析构函数 + ~PoolAttr() + { + kernel_size.clear(); + poolStrides.clear(); + poolPads.clear(); + poolDilations.clear(); + } + }; + + struct ConcatAttr : AclOpAttr + { + int64_t tensorNum; + int64_t dim; + + ~ConcatAttr() + { + } + }; + + struct GatherAttr : AclOpAttr + { + int64_t dim; + + ~GatherAttr() + { + } + }; + + struct ScatterAttr : AclOpAttr + { + int64_t axis; + int64_t reduction; + + ~ScatterAttr() + { + } + }; + + struct StrideAttr : AclOpAttr + { + vector begins; + vector ends; + vector steps; + vector axes; + ~StrideAttr() + { + begins.clear(); + ends.clear(); + steps.clear(); + axes.clear(); + } + }; + + struct RangeAttr : AclOpAttr + { + int64_t start; + int64_t end; + int64_t step; + + ~RangeAttr() + { + } + }; + + struct LeakyReluAttr : AclOpAttr + { + float negativeSlope; + bool selfIsResult; + + ~LeakyReluAttr() + { + } + }; + + struct DropoutAttr : AclOpAttr + { + float p; + bool train; + int64_t seed; + int64_t offset; + float scale; + + ~DropoutAttr() + { + } + }; + + struct EmbeddingAttr : AclOpAttr + { + int64_t numEmbeddings; + // int64_t embeddingDim; + int64_t paddingIdx; + bool scaleGradByFreq; + // bool sparse; + // bool isSparse; + // bool isDense; + + ~EmbeddingAttr() + { + } + }; + + struct SplitWithSizeAttr : AclOpAttr + { + vector splitSize; + int64_t dim; + ~SplitWithSizeAttr() + { + splitSize.clear(); + } + }; + + struct SoftmaxAttr : AclOpAttr + { + int64_t dim; + ~SoftmaxAttr() + { + } + }; + + struct BatchNormAttr : AclOpAttr + { + bool is_train; + float momentum; + float eps; + ~BatchNormAttr() + { + } + }; + + struct LayerNormAttr : AclOpAttr + { + float eps; + vector normalizedShape; + int64_t size; + ~LayerNormAttr() + { + normalizedShape.clear(); + } + }; + + struct FlashAttentionAttr : AclOpAttr + { + vector prefix; + vector qStartIdx; + vector kvStartIdx; + float scale; + 
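// scale above is the softmax scaling factor (typically 1/sqrt(head_dim)) and
// keepProb below is the dropout keep probability (1 - p); the exact mapping to
// the aclnnFlashAttentionScoreV2 arguments is an assumption and is not spelled
// out in this hunk.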
float keepProb; + int64_t preToken; + int64_t nextToken; + int64_t headNum; + string inputLayout; + int64_t innerPrecise; + int64_t sparseMode; + int64_t psetype; + bool hasRealshift; + bool hasDropmask; + bool hasPaddingmask; + bool hasAttentmask; + + ~FlashAttentionAttr() + { + prefix.clear(); + qStartIdx.clear(); + kvStartIdx.clear(); + } + }; + + struct NanToNumAttr : AclOpAttr + { + float nan; + float posinf; + float neginf; + ~NanToNumAttr() + { + } + }; +} \ No newline at end of file diff --git a/python/jittor/extern/acl/acl_op_exec.cc b/python/jittor/extern/acl/acl_op_exec.cc index 07b35145..b2b4dcb2 100644 --- a/python/jittor/extern/acl/acl_op_exec.cc +++ b/python/jittor/extern/acl/acl_op_exec.cc @@ -1,6 +1,6 @@ // *************************************************************** -// Copyright (c) 2023 Jittor. All Rights Reserved. -// Maintainers: Dun Liang . +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . // This file is subject to the terms and conditions defined in // file 'LICENSE.txt', which is part of this source code package. // *************************************************************** @@ -9,6 +9,8 @@ #include #include #include +#include +#include #include "common.h" #include "op.h" #include "acl_jittor.h" @@ -29,660 +31,472 @@ #include "ops/op_register.h" #include "opt/tuner_manager.h" #include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "aclops/aclops.h" +namespace jittor +{ + void free_var_mem(Var *v); -namespace jittor { + unordered_map opname_map = { + // unary op + {ns_cast, "Cast"}, + {ns_negative, "Neg"}, + {ns_abs, "Abs"}, + {ns_exp, "Exp"}, + {ns_log, "Log"}, + {ns_sqrt, "Sqrt"}, + {ns_ceil, "Ceil"}, + {ns_floor, "Floor"}, + {ns_round, "Round"}, + // m(round_int) + // m(floor_int) + // m(ceil_int) + {ns_sin, "Sin"}, + {ns_cos, "Cos"}, + {ns_tan, "Tan"}, + {ns_asin, "Asin"}, + {ns_acos, "Acos"}, + {ns_atan, "Atan"}, + {ns_sinh, "Sinh"}, + {ns_cosh, "Cosh"}, + {ns_tanh, "Tanh"}, + {ns_asinh, "Asinh"}, + {ns_acosh, "Acosh"}, + {ns_atanh, "Atanh"}, + {ns_sigmoid, "Sigmoid"}, + {ns_erf, "Erf"}, + {ns_erfinv, "Erfinv"}, + {ns_logical_not, "LogicalNot"}, + {ns_bitwise_not, "BitwiseNot"}, + // binary op + {ns_pow, "Pow"}, + {ns_maximum, "Maximum"}, + {ns_minimum, "Minimum"}, + {ns_add, "Add"}, + {ns_subtract, "Sub"}, + {ns_multiply, "Mul"}, + {ns_divide, "RealDiv"}, + {ns_floor_divide, "FloorDiv"}, + {ns_mod, "Mod"}, + {ns_less, "Less"}, + {ns_less_equal, "LessEqual"}, + {ns_greater, "Greater"}, + {ns_greater_equal, "GreaterEqual"}, + {ns_equal, "Equal"}, + {ns_not_equal, "NotEqual"}, + {ns_left_shift, "LeftShift"}, + {ns_right_shift, "RightShift"}, + {ns_logical_and, "LogicalAnd"}, + {ns_logical_or, "LogicalOr"}, + {ns_logical_xor, "LogicalXor"}, + {ns_bitwise_and, "BitwiseAnd"}, + {ns_bitwise_or, "BitwiseOr"}, + {ns_bitwise_xor, "BitwiseXor"}, + }; -using std::swap; - -void printDeviceData(const vector& output_desc, const vector& output_data, const string& name = "", bool input=true) { - LOGir << "name: " << name; - if(input) - LOGir << "is input"; - else - LOGir << "is ouput"; - for (size_t i = 0; i < output_desc.size(); ++i) { - void* base_addr = aclGetDataBufferAddr(output_data[i]); - LOGir << "addr of data[" << i << "] :" << base_addr; - size_t num_dims = aclGetTensorDescNumDims(output_desc[i]); - size_t total_size = 1; - std::vector dims(num_dims); - - std::cout << "shape of data: "; - for (size_t j = 0; j < num_dims; ++j) { - aclGetTensorDescDimV2(output_desc[i], j, &dims[j]); - total_size *= dims[j]; - std::cout << 
dims[j] << ", "; - } - int evey_batch_size = total_size/dims[0]; - std::cout << std::endl; - - // for(int i= 0; i < dims[0]; i++) { - // evey_batch_size = 16; - // std::vector host_buffer(evey_batch_size); - // void* offset_addr = static_cast(base_addr) + i * evey_batch_size * sizeof(float); - // aclrtMemcpy(host_buffer.data(), evey_batch_size * sizeof(float), offset_addr, evey_batch_size * sizeof(float), ACL_MEMCPY_DEVICE_TO_HOST); - // std::cout << "batch[" << i << "]:"; - // for (size_t k = 0; k < evey_batch_size; ++k) { - // std::cout << host_buffer[k] << ", "; - // } - // std::cout << std::endl; - // if(i >= 3) - // break; - // } - } -} - -struct AclOpRunner { - string name; - vector input_desc; - vector output_desc; - vector input_data; - vector output_data; - aclopAttr *attr; - vector> input_host; - vector> input_host_32; - - AclOpRunner(const string& name) : name(name) { - attr = aclopCreateAttr(); - } - - ~AclOpRunner() { - for (auto i : input_desc) aclDestroyTensorDesc(i); - for (auto i : output_desc) aclDestroyTensorDesc(i); - for (auto i : input_data) aclDestroyDataBuffer(i); - for (auto i : output_data) aclDestroyDataBuffer(i); - aclopDestroyAttr(attr); - } - - aclDataType get_dtype(NanoString s) { - if (s == ns_float32) return ACL_FLOAT; - if (s == ns_float16) return ACL_FLOAT16; - if (s == ns_int64) return ACL_INT64; - if (s == ns_int32) return ACL_INT32; - if (s == ns_int8) return ACL_INT8; - if (s == ns_int16) return ACL_INT16; - if (s == ns_uint8) return ACL_UINT8; - if (s == ns_uint16) return ACL_UINT16; - if (s == ns_uint32) return ACL_UINT32; - if (s == ns_bool) return ACL_BOOL; - LOGf << "Not supported dtype: " << s; - return ACL_FLOAT; - } - - void add(Var* v, bool is_input, int format=ACL_FORMAT_ND) { - int64_t shape[v->shape.size()]; - for (int i=0; ishape.size(); i++) shape[i] = v->shape[i]; - - auto desc = aclCreateTensorDesc(get_dtype(v->dtype()), v->shape.size(), &shape[0], (aclFormat)format); - aclSetTensorFormat(desc, (aclFormat)format); - aclSetTensorShape(desc, v->shape.size(), &shape[0]); - LOGv << "aclCreateTensorDesc" << (int)get_dtype(v->dtype()) << v->shape.size() << &shape[0] << format; - auto data = aclCreateDataBuffer(v->mem_ptr, v->size); - LOGv << "aclCreateDataBuffer" << v->mem_ptr << v->size; - ASSERT(desc && data); - if (is_input) { - input_desc.push_back(desc); - input_data.push_back(data); - } else { - output_desc.push_back(desc); - output_data.push_back(data); - } - } - - void add_input_host(vector v, int dtype=ACL_UINT64) { - int64_t shape[1]; - shape[0] = v.size(); - auto desc = aclCreateTensorDesc((aclDataType)dtype, 1, &shape[0], ACL_FORMAT_ND); - aclSetTensorPlaceMent(desc, ACL_MEMTYPE_HOST_COMPILE_INDEPENDENT); - LOGv << "aclCreateTensorDesc" << dtype << 1 << &shape[0] << ACL_FORMAT_ND; - auto data = aclCreateDataBuffer(&v[0], v.size()*sizeof(uint64)); - ASSERT(desc && data); - LOGv << "aclCreateDataBuffer" << &v[0] << v.size()*sizeof(uint64); - input_desc.push_back(desc); - input_data.push_back(data); - input_host.emplace_back(move(v)); - LOGv << "move" << input_host.back().data(); - } - - void add_input_host_scalar(vector v, int dtype=ACL_UINT32) { - int64_t shape[1]; - shape[0] = v.size(); - auto x = (int*)&v[0]; - x[0] = (int32)v[0]; - auto desc = aclCreateTensorDesc((aclDataType)dtype, 0, &shape[0], ACL_FORMAT_ND); - aclSetTensorPlaceMent(desc, ACL_MEMTYPE_HOST_COMPILE_INDEPENDENT); - LOGv << "aclCreateTensorDesc" << dtype << 1 << &shape[0] << ACL_FORMAT_ND; - auto data = aclCreateDataBuffer(&v[0], v.size()*sizeof(uint32)); - 
ASSERT(desc && data); - LOGv << "aclCreateDataBuffer" << &v[0] << v.size()*sizeof(uint32); - input_desc.push_back(desc); - input_data.push_back(data); - input_host.emplace_back(move(v)); - } - - void add_input_host_nv(NanoVector nv, int dtype=ACL_UINT64) { - vector v(nv.size()); - for (int i=0; i v(nv.size()); - for (int i=0; i value) { - // LOGir << "string vector" << "set_attr" << key << value; - CHECK(aclopSetAttrListInt(attr, key.c_str(), value.size(), &value[0])==0); - } - void set_attr(const string& key, string value) { - // LOGir << "string string" << "set_attr" << key << value; - CHECK(aclopSetAttrString(attr, key.c_str(), value.c_str())==0); - } - void set_attr(const char* key, const char* value) { - // LOGir << "char" << "set_attr" << key << value; - CHECK(aclopSetAttrString(attr, key, value)==0); - } - - void run() { - // printDeviceData(input_desc, input_data, name); - - LOGv << "run" << name << input_desc.size() << output_desc.size(); - if (!PyGILState_Check()) { - ASSERT(0==aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream)); - } else { - int ret; - Py_BEGIN_ALLOW_THREADS - ret = aclopCompileAndExecuteV2(name.c_str(), input_desc.size(), &input_desc[0], &input_data[0], output_desc.size(), &output_desc[0], &output_data[0], attr, ACL_ENGINE_SYS, ACL_COMPILE_SYS, NULL, aclstream); - Py_END_ALLOW_THREADS - if (ret != 0) - LOGf << "aclopCompileAndExecuteV2" << name << "failed return" << ret; - } - ASSERT(0==aclrtSynchronizeDevice()); - - // printDeviceData(output_desc, output_data, name, false); - } -}; - -void free_var_mem(Var* v); - - -unordered_map opname_map = { - // unary op - {ns_cast, "Cast"}, - {ns_negative, "Neg"}, - {ns_abs, "Abs"}, - {ns_exp, "Exp"}, - {ns_log, "Log"}, - {ns_sqrt, "Sqrt"}, - {ns_ceil, "Ceil"}, - {ns_floor, "Floor"}, - {ns_round, "Round"}, - // m(round_int) - // m(floor_int) - // m(ceil_int) - {ns_sin, "Sin"}, - {ns_cos, "Cos"}, - {ns_tan, "Tan"}, - {ns_asin, "Asin"}, - {ns_acos, "Acos"}, - {ns_atan, "Atan"}, - {ns_sinh, "Sinh"}, - {ns_cosh, "Cosh"}, - {ns_tanh, "Tanh"}, - {ns_asinh, "Asinh"}, - {ns_acosh, "Acosh"}, - {ns_atanh, "Atanh"}, - {ns_sigmoid, "Sigmoid"}, - {ns_erf, "Erf"}, - {ns_erfinv, "Erfinv"}, - {ns_logical_not, "LogicalNot"}, - {ns_bitwise_not, "BitwiseNot"}, - // binary op - {ns_pow, "Pow"}, - {ns_maximum, "Maximum"}, - {ns_minimum, "Minimum"}, - {ns_add, "Add"}, - {ns_subtract, "Sub"}, - {ns_multiply, "Mul"}, - {ns_divide, "RealDiv"}, - {ns_floor_divide, "FloorDiv"}, - {ns_mod, "Mod"}, - {ns_less, "Less"}, - {ns_less_equal, "LessEqual"}, - {ns_greater, "Greater"}, - {ns_greater_equal, "GreaterEqual"}, - {ns_equal, "Equal"}, - {ns_not_equal, "NotEqual"}, - {ns_left_shift, "LeftShift"}, - {ns_right_shift, "RightShift"}, - {ns_logical_and, "LogicalAnd"}, - {ns_logical_or, "LogicalOr"}, - {ns_logical_xor, "LogicalXor"}, - {ns_bitwise_and, "BitwiseAnd"}, - {ns_bitwise_or, "BitwiseOr"}, - {ns_bitwise_xor, "BitwiseXor"}, - -}; - -void fallback_cpu(Op* op) { - LOGy << "!!! 
fallback_cpu " << op; - use_cuda = 0; - for (auto v : op->inputs()) { - if (v->mem_ptr && v->allocator->is_cuda()) { - migrate_to_cpu(v, exe.allocator); - } - } - for (auto v : op->outputs()) { - if (v->mem_ptr && v->allocator->is_cuda()) { - migrate_to_cpu(v, exe.allocator); - } - } - op->flags.set(NodeFlags::_cpu); - op->flags.set(NodeFlags::_cuda, 0); - if (op->name() == string("fused")) { - auto fop = (FusedOp*)op; - for (auto op : fop->ops) { - op->flags.set(NodeFlags::_cpu); - op->flags.set(NodeFlags::_cuda, 0); - } - } - op->do_run(); - use_cuda = 1; -} - -/* - check compile - if compiled: exec - else: compile - check is fused - check is relay - else - compile func = try exec - if failed: fallback_cpu - else - try compile - if failed: fallback_cpu -*/ - -extern jit_op_entry_t (*do_compile_hook)(Op*); -jit_op_entry_t do_compile_inner(Op* op); - -void try_exec_and_fallback_cpu(Op* op) { - LOGv << "try_exec_and_fallback_cpu " << op; - auto fop = (FusedOp*)op; - - vector new_alloced; - int fallback = 0; - try { - for (Op* op : fop->ops) { - for (auto out : op->outputs()) { - if (out->mem_ptr) continue; - out->alloc(exe.temp_allocator); - new_alloced.push_back(out); - } - if (op->name() == string("unary")) { - auto uop = (UnaryOp*)op; - AclOpRunner op("..."); - op.add(uop->x, true); - op.add(uop->y, false); - if (uop->ns == ns_cast) { - op.set_attr("dst_type", (int64_t)op.get_dtype(uop->y->dtype())); - } - auto iter = opname_map.find(uop->ns); - ASSERT(iter != opname_map.end()) << "op " << uop->ns << " not found"; - op.name = iter->second; - op.run(); - } else - if (op->name() == string("binary")) { - auto bop = (BinaryOp*)op; - AclOpRunner op("..."); - op.add(bop->x, true); - op.add(bop->y, true); - op.add(bop->z, false); - auto iter = opname_map.find(bop->ns); - ASSERT(iter != opname_map.end()) << "op " << bop->ns << " not found"; - op.name = iter->second; - if (bop->x->dtype() == ns_bool and bop->y->dtype() == ns_bool) - { - // BitwiseOr, BitwiseAnd, BitwiseXor -> LogicalOr, LogicalAnd, LogicalXor - if (bop->ns == ns_bitwise_or) { - op.name = "LogicalOr"; - } else if (bop->ns == ns_bitwise_and) { - op.name = "LogicalAnd"; - } else if (bop->ns == ns_bitwise_xor) { - op.name = "LogicalXor"; - } - } - op.run(); - } else - if (op->name() == string("ternary")) { - auto top = (TernaryOp*)op; - AclOpRunner op("Select"); - op.add(top->cond, true); - op.add(top->x, true); - op.add(top->y, true); - op.add(top->z, false); - op.run(); - } else - if (op->name() == string("array")) { - auto aop = (ArrayOp*)op; - aclrtMemcpy(aop->output->mem_ptr, aop->output->size, aop->ptr(), aop->output->size, ACL_MEMCPY_HOST_TO_DEVICE); - } else - if (op->name() == string("reduce")) { - auto rop = (ReduceOp*)op; - AclOpRunner op(""); - if (rop->ns == ns_add) - op.name = "ReduceSum"; - else if (rop->ns == ns_multiply) - op.name = "ReduceProd"; - else if (rop->ns == ns_maximum) - op.name = "ReduceMax"; - else if (rop->ns == ns_minimum) - op.name = "ReduceMin"; - else if (rop->ns == ns_mean) - op.name = "ReduceMean"; - else - LOGf << "op " << rop->ns << " not supported"; - op.add(rop->x, true); - vector axes; - for (int i=0; ix->shape.size(); i++) - if (rop->reduce_mask & (1<y, false); - op.set_attr("keep_dims", false); - if (rop->ns == ns_mean) { - // operation: An optional int32 from 1(SUM), 2(ASUM), 3(SUMSQ), and 4(MEAN), specifying the reduction algorithm. Defaults to "1". 
- op.set_attr("operation", 4); - } - op.run(); - } else - if (op->name() == string("broadcast_to")) { - auto bop = (BroadcastToOp*)op; - AclOpRunner op("Expand"); - NanoVector xshape, xshape_bk = bop->x->shape; - NanoVector zshape = bop->z->shape; - for (int i=0; ibcast_mask & (1<x->shape = xshape; - op.add(bop->x, true); - bop->x->shape = xshape_bk; - op.add_input_host_nv(zshape, ACL_INT64); - op.add(bop->z, false); - op.run(); - } - else - if (op->name() == string("fuse_transpose")) { - // replace fuse_transpose with transpose - auto top = (TransposeOp*)op; - AclOpRunner op("Transpose"); - op.add(top->x, true); - op.add(top->y, false); - vector axes; - for (int i=0; iaxes.size(); i++) - axes.push_back(top->axes[i]); - op.add_input_host(axes, ACL_INT64); - op.run(); - } else + void fallback_cpu(Op *op) + { + LOGy << "!!! fallback_cpu " << op; + use_cuda = 0; + for (auto v : op->inputs()) + { + if (v->mem_ptr && v->allocator->is_cuda()) { - LOGf << "op " << op->name() << " not supported"; + migrate_to_cpu(v, exe.allocator); } } - } catch (std::exception& e) { - fallback = 1; - LOGir << "fallback cpu" << e.what(); - } - for (auto v : new_alloced) { - free_var_mem(v); - } - if (fallback) { - fallback_cpu(op); - } -} - -extern int current_seed; -extern int64 current_offset; - -static unordered_map> acl_ops = { -{"curand_random", [¤t_seed, ¤t_offset](Op* op) { - auto _op = (RandomOp*)op; - AclOpRunner runner(_op->type == ns_uniform ? "StatelessRandomUniformV2" : "StatelessRandomNormalV2"); - auto out = op->output(0); - runner.add_input_host_nv(out->shape, ACL_INT64); // shape - runner.add_input_host({current_seed}); // seed - runner.add_input_host({0,current_offset}); // offset - runner.add_input_host_scalar({1}, ACL_INT32); // algorithm - runner.add(out, false); - runner.set_attr("dtype", (int64_t)runner.get_dtype(out->dtype())); - runner.run(); - // aclrtSynchronizeDevice(); - current_offset += out->numel(); -}}, -{"cublas_matmul", [&](Op* op) { - struct MatmulOp : Op { - Var* a, *b, *c; - bool trans_a, trans_b; - }; - auto _op = (MatmulOp*)op; - AclOpRunner runner("MatMul"); - runner.add(_op->a, true); - runner.add(_op->b, true); - runner.add(_op->c, false); - runner.set_attr("transpose_x1", _op->trans_a); - runner.set_attr("transpose_x2", _op->trans_b); - runner.run(); -}}, -{"cublas_batched_matmul", [&](Op* op) { - struct BatchedMatmulOp : Op { - Var* a, *b, *c; - bool adj_x1, adj_x2; - }; - auto _op = (BatchedMatmulOp*)op; - AclOpRunner runner("BatchMatMul"); - runner.add(_op->a, true); - runner.add(_op->b, true); - runner.add(_op->c, false); - runner.set_attr("adj_x1", _op->adj_x1); - runner.set_attr("adj_x2", _op->adj_x2); - runner.run(); -}}, -{"cudnn_conv", [](Op* op) { - struct ConvOp : Op { - Var* x, * w, * y; - int strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; - string xformat, wformat, yformat; - void run_acl() { - AclOpRunner runner("Conv2D"); - runner.add(x, true, ACL_FORMAT_NCHW); - runner.add(w, true, ACL_FORMAT_NCHW); - runner.add(y, false, ACL_FORMAT_NCHW); - runner.set_attr("strides", vector{1,1,strideh,stridew}); - runner.set_attr("pads", vector{paddingh,paddingh,paddingw,paddingw}); - runner.set_attr("dilations", vector{1,1,dilationh,dilationw}); - runner.set_attr("groups", groups); - ASSERT(xformat=="abcd" && yformat=="abcd" && wformat=="oihw"); - runner.set_attr("data_format", "NCHW"); - runner.run(); + for (auto v : op->outputs()) + { + if (v->mem_ptr && v->allocator->is_cuda()) + { + migrate_to_cpu(v, exe.allocator); + } } - }; - auto _op = 
(ConvOp*)op; - _op->run_acl(); -}}, -{"cudnn_conv_backward_x", [](Op* op) { - struct ConvBackwardXOp : Op { - Var* w, * dy, * dx; - int xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; - string xformat, wformat, yformat; - void run_acl() { - AclOpRunner runner("Conv2DBackpropInput"); - runner.add_input_host_nv32(dx->shape); // 10,3,50,50 - // runner.add_input_host_nv32(dy->shape); // 10,3,50,50 - runner.add(w, true, ACL_FORMAT_NCHW); // 4,3,3,3 - aclSetTensorDescName(runner.input_desc.back(), "filter"); - runner.add(dy, true, ACL_FORMAT_NCHW); // 10,4,48,48 - aclSetTensorDescName(runner.input_desc.back(), "out_backprop"); - runner.add(dx, false, ACL_FORMAT_NCHW); // 10,3,50,50 - aclSetTensorDescName(runner.input_desc.back(), "y"); - runner.set_attr("strides", vector{1,1,strideh,stridew}); - runner.set_attr("pads", vector{paddingh,paddingh,paddingw,paddingw}); - runner.set_attr("dilations", vector{1,1,dilationh,dilationw}); - runner.set_attr("groups", groups); - runner.set_attr("data_format", "NCHW"); - // runner.set_attr("dataFormat", "NCHW"); - // runner.set_attr("data_format", "NCHW"); - ASSERT(xformat=="abcd" && yformat=="abcd" && wformat=="oihw"); - runner.run(); + op->flags.set(NodeFlags::_cpu); + op->flags.set(NodeFlags::_cuda, 0); + if (op->name() == string("fused")) + { + auto fop = (FusedOp *)op; + for (auto op : fop->ops) + { + op->flags.set(NodeFlags::_cpu); + op->flags.set(NodeFlags::_cuda, 0); + } } - }; - auto _op = (ConvBackwardXOp*)op; - _op->run_acl(); -}}, -{"cudnn_conv_backward_w", [](Op* op) { - struct ConvBackwardWOp : Op { - Var* x, * dy, * dw; - int kh, kw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; - string xformat, wformat, yformat; - void run_acl() { - AclOpRunner runner("Conv2DBackpropFilter"); - runner.add(x, true, ACL_FORMAT_NCHW); - runner.add_input_host_nv32(dw->shape); - runner.add(dy, true, ACL_FORMAT_NCHW); - runner.add(dw, false, ACL_FORMAT_NCHW); - runner.set_attr("strides", vector{1,1,strideh,stridew}); - runner.set_attr("pads", vector{paddingh,paddingh,paddingw,paddingw}); - runner.set_attr("dilations", vector{1,1,dilationh,dilationw}); - runner.set_attr("groups", groups); - runner.set_attr("data_format", "NCHW"); - // runner.set_attr("dataFormat", "NCHW"); - // runner.set_attr("data_format", "NCHW"); - // runner.set_attr("data_origin_format", "NCHW"); - ASSERT(xformat=="abcd" && yformat=="abcd" && wformat=="oihw"); - runner.run(); + op->do_run(); + use_cuda = 1; + } + + /* + check compile + if compiled: exec + else: compile + check is fused + check is relay + else + compile func = try exec + if failed: fallback_cpu + else + try compile + if failed: fallback_cpu + */ + + extern jit_op_entry_t (*do_compile_hook)(Op *); + jit_op_entry_t do_compile_inner(Op *op); + + void try_exec_and_fallback_cpu(Op *op) + { + aclrtSynchronizeStream(aclstream); + auto fop = (FusedOp *)op; + + std::set new_alloced; + map op_indeg; + map var_outdeg; + std::queue queue; + + for (Op *op : fop->ops) + op_indeg[op] = 0; + + map> out_map; + map> from; + + int len = 0; + for (Op *v : fop->ops) + { + for (auto in : v->inputs()) + from[in].push_back(v); + ++len; } + for (Op *u : fop->ops) + { + for (auto out : u->outputs()) + { + if (from.find(out) != from.end()) + { + for (auto v : from[out]) + { + ++op_indeg[v]; + ++var_outdeg[out]; + out_map[u].push_back(v); + } + } + } + } + for (Op *op : fop->ops) + { + if (op_indeg[op] == 0) + queue.push(op); + } + + int total = 0; + int fallback = 0; + try + { + while (!queue.empty()) + { 
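+                // Descriptive note (added for clarity, not part of the original patch):
+                // this loop is a Kahn-style topological execution of the fused op's
+                // sub-ops. An op is popped only after all of its producers have run;
+                // any missing outputs are allocated as temporaries, the op is then
+                // dispatched to the matching ACL runner below (unary/binary/ternary/
+                // array/reduce/broadcast_to/fuse_transpose), and temporaries are freed
+                // as soon as their last consumer finishes. Any ACL failure raised in
+                // this block is caught further down and routed to fallback_cpu(op).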
+ total++; + + for (auto in : op->inputs()) + { + ASSERT(in->mem_ptr); + } + auto op = queue.front(); + queue.pop(); + for (auto out : op->outputs()) + { + if (out->mem_ptr) + continue; + out->alloc(exe.allocator); + new_alloced.insert(out); + } + for (auto out : out_map[op]) + { + --op_indeg[out]; + if (op_indeg[out] == 0) + queue.push(out); + } + if (op->name() == string("unary")) + { + auto uop = (UnaryOp *)op; + UnaryOpRunner op; + op.add(uop->x, true); + op.add(uop->y, false); + auto iter = opname_map.find(uop->ns); + ASSERT(iter != opname_map.end()) << "op " << uop->ns << " not found"; + op.name = iter->second; + op.jt_name = uop->name(); + op.run(); + } + else if (op->name() == string("binary")) + { + auto bop = (BinaryOp *)op; + BinaryOpRunner op; + op.add(bop->x, true); + op.add(bop->y, true); + op.add(bop->z, false); + auto iter = opname_map.find(bop->ns); + ASSERT(iter != opname_map.end()) << "op " << bop->ns << " not found"; + op.name = iter->second; + op.jt_name = bop->name(); + + if (bop->x->dtype() == ns_bool and bop->y->dtype() == ns_bool) + { + // BitwiseOr, BitwiseAnd, BitwiseXor -> LogicalOr, LogicalAnd, LogicalXor + if (bop->ns == ns_bitwise_or) + { + op.name = "LogicalOr"; + } + else if (bop->ns == ns_bitwise_and) + { + op.name = "LogicalAnd"; + } + else if (bop->ns == ns_bitwise_xor) + { + op.name = "LogicalXor"; + } + } + op.run(); + } + else if (op->name() == string("ternary")) + { + auto top = (TernaryOp *)op; + TernaryOpRunner op; + op.add(top->cond, true); + op.add(top->x, true); + op.add(top->y, true); + op.add(top->z, false); + op.run(); + } + else if (op->name() == string("array")) + { + auto aop = (ArrayOp *)op; + aclrtMemcpy(aop->output->mem_ptr, aop->output->size, aop->ptr(), aop->output->size, ACL_MEMCPY_HOST_TO_DEVICE); + } + else if (op->name() == string("reduce")) + { + auto rop = (ReduceOp *)op; + ReduceOpRunner op; + if (rop->ns == ns_add) + op.op_idx = 9; + else if (rop->ns == ns_multiply) + // TODO unsupported the multi dim + op.op_idx = 999; + else if (rop->ns == ns_maximum) + op.op_idx = 11; + else if (rop->ns == ns_minimum) + op.op_idx = 12; + else if (rop->ns == ns_mean) + op.op_idx = 10; + else + LOGf << "op " << rop->ns << " not supported"; + op.add(rop->x, true); + + ReduceAttr *attr = new ReduceAttr(); + for (int i = 0; i < rop->x->shape.size(); i++) + if (rop->reduce_mask & (1 << i)) + attr->axes.push_back(i); + if (rop->x->shape.size() == rop->y->shape.size()) + attr->keepdims = true; + else + attr->keepdims = false; + + op.op_attr.reset(attr); + op.add(rop->y, false); + op.run(); + aclrtSynchronizeStream(aclstream); + } + else if (op->name() == string("broadcast_to")) + { + auto bop = (BroadcastToOp *)op; + ExpandOpRunner op; + op.jt_name = "expand"; + NanoVector xshape, xshape_bk = bop->x->shape; + NanoVector zshape = bop->z->shape; + + for (int i = 0; i < zshape.size(); i++) + { + if (bop->bcast_mask & (1 << i)) + { + xshape.push_back(1); + } + else + { + xshape.push_back(zshape[i]); + } + } + bop->x->shape = xshape; + op.add(bop->x, true); + // bop->x->shape = xshape_bk; + op.add(bop->z, false); + op.run(); + bop->x->shape = xshape_bk; + aclrtSynchronizeStream(aclstream); + } + else if (op->name() == string("fuse_transpose")) + { + // replace fuse_transpose with transpose + auto top = (TransposeOp *)op; + TransposeOpRunner op; + op.add(top->x, true); + op.add(top->y, false); + op.jt_name = "transpose"; + + ReduceAttr *attr = new ReduceAttr(); + for (int i = 0; i < top->axes.size(); i++) + attr->axes.push_back(top->axes[i]); + 
op.op_attr.reset(attr); + + op.run(); + } + else + { + LOGf << "op " << op->name() << " not supported"; + } + + for (auto in : op->inputs()) + { + --var_outdeg[in]; + if (var_outdeg[in] == 0) + { + if (new_alloced.find(in) != new_alloced.end()) + { + free_var_mem(in); + new_alloced.erase(in); + } + } + } + } + } + catch (std::exception &e) + { + fallback = 1; + LOGir << "fallback cpu" << e.what(); + } + for (auto v : new_alloced) + { + free_var_mem(v); + } + if (fallback) + { + fallback_cpu(op); + } + } + + extern int current_seed; + extern int64 current_offset; + + static unordered_map> acl_ops = { + {"curand_random", [¤t_seed, ¤t_offset](Op *op) + { + auto _op = (RandomOp *)op; + RandomOpRunner runner(_op->type == ns_uniform ? "RandomUniform" : "RandomNormal"); + auto out = op->output(0); + RandomAttr *attr = new RandomAttr(); + attr->seed = current_seed; + attr->offset = current_offset; + runner.jt_name = "random"; + runner.op_attr.reset(attr); + + runner.add(out, false); + runner.run(); + current_offset += out->numel(); + }}, }; - auto _op = (ConvBackwardWOp*)op; - _op->run_acl(); -}}, -// {"cub_arg_reduce", } -}; -static void exec_mapped_acl_ops(Op* op) { - auto iter = acl_ops.find(op->name()); - if (iter != acl_ops.end()) { - LOGv << "exec acl op " << op->name() << op; - iter->second(op); - } else { - LOGf << "op " << op->name() << " not supported"; + static void exec_mapped_acl_ops(Op *op) + { + auto iter = acl_ops.find(op->name()); + if (iter != acl_ops.end()) + { + LOGv << "exec acl op " << op->name() << op; + iter->second(op); + } + else + { + LOGf << "op " << op->name() << " not supported"; + } } -} -static jit_op_entry_t acl_do_compile(Op* op) { - LOGv << "compile" << op; - OpCompiler oc(op); - string* src = &oc.src; - for (auto op_type : op_types) - op_type->post_pass(&oc); - string src_after_passes; - // if is fused op - if (oc.op) { - TunerManager tm(&oc); - src_after_passes = tm.tune(); - src = &src_after_passes; - } - op->compile_optimize(*src); - if (!op->flags.get(NodeFlags::_cuda)) { - LOGv << "compile cpu"; - return oc.compile(op->get_jit_key(get_jk()), *src); - } - if (op->name() == string("fused")) { - FusedOp* fop = (FusedOp*)op; - // if is a relayed op - if (fop->context->vrm.relay_groups.size()) { - LOGv << "relay fused op"; + static jit_op_entry_t acl_do_compile(Op *op) + { + LOGv << "compile" << op; + OpCompiler oc(op); + string *src = &oc.src; + for (auto op_type : op_types) + op_type->post_pass(&oc); + string src_after_passes; + // if is fused op + if (oc.op) + { + TunerManager tm(&oc); + src_after_passes = tm.tune(); + src = &src_after_passes; + } + op->compile_optimize(*src); + if (!op->flags.get(NodeFlags::_cuda)) + { + LOGv << "compile cpu"; return oc.compile(op->get_jit_key(get_jk()), *src); - } else { - return &try_exec_and_fallback_cpu; } - } else - if (op->name() == string("code")) { - CodeOp* cop = (CodeOp*)op; - if (cop->cuda_src.find("acl") != string::npos) { - LOGv << "compile acl op"; + if (op->name() == string("fused")) + { + FusedOp *fop = (FusedOp *)op; + // if is a relayed op + if (fop->context->vrm.relay_groups.size()) + { + LOGv << "relay fused op"; + return oc.compile(op->get_jit_key(get_jk()), *src); + } + else + { + return &try_exec_and_fallback_cpu; + } + } + else if (op->name() == string("code")) + { + CodeOp *cop = (CodeOp *)op; + if (cop->cuda_src.find("acl") != string::npos) + { + LOGv << "compile acl op"; + return oc.compile(op->get_jit_key(get_jk()), *src); + } + else + { + return &exec_mapped_acl_ops; + } + } + else if 
(strncmp(op->name(), "hccl", 4) == 0) + { + LOGv << "Compiling HCCL op: " << op->name(); return oc.compile(op->get_jit_key(get_jk()), *src); - } else { + } + else + { + LOGv << "compile finish" << op; return &exec_mapped_acl_ops; } - } else - { - LOGv << "compile finish" << op; - return &exec_mapped_acl_ops; + return do_compile_inner(op); } - return do_compile_inner(op); -} -// from op_register.cc -extern unordered_map op_info_map; + // from op_register.cc + extern unordered_map op_info_map; -void init_acl_ops() { - do_compile_hook = acl_do_compile; - vector to_erase; - for (auto& kv : op_info_map) { - if (startswith(kv.first, "cu") && acl_ops.count(kv.first) == 0) { - to_erase.push_back(kv.first); + void init_acl_ops() + { + do_compile_hook = acl_do_compile; + vector to_erase; + for (auto &kv : op_info_map) + { + if (startswith(kv.first, "cu") && acl_ops.count(kv.first) == 0) + { + to_erase.push_back(kv.first); + } + } + for (auto &k : to_erase) + { + LOGv << "op not supported: " << k << ", erase it."; + op_info_map.erase(k); } } - for (auto& k : to_erase) { - LOGv << "op not supported: " << k << ", erase it."; - op_info_map.erase(k); - } -} - } // jittor \ No newline at end of file diff --git a/python/jittor/extern/acl/aclnn/aclnn.cc b/python/jittor/extern/acl/aclnn/aclnn.cc new file mode 100644 index 00000000..3452c7be --- /dev/null +++ b/python/jittor/extern/acl/aclnn/aclnn.cc @@ -0,0 +1,58 @@ +#include +#include +#include "aclnn.h" + +int64_t GetShapeSize(const std::vector& shape) { + int64_t shapeSize = 1; + for (auto i : shape) { + shapeSize *= i; + } + return shapeSize; +} + +void PrintOutResult(std::vector &shape, void** deviceAddr) { + auto size = GetShapeSize(shape); + std::vector resultData(size, 0); + auto ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]), + *deviceAddr, size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return); + for (int64_t i = 0; i < size; i++) { + LOG_PRINT("mean result[%ld] is: %d\n", i, resultData[i]); + } +} + +/*int Init(int32_t deviceId) { + // 固定写法,AscendCL初始化 + auto ret = aclInit(nullptr); + CHECK_RET(ret == ACL_SUCCESS or ret == 100002, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret); + ret = aclrtSetDevice(deviceId); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret); + //ret = aclrtCreateStream(stream); + //CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret); + return 0; +}*/ + +/* +template +int CreateAclTensor(const std::vector& hostData, const std::vector& shape, void** deviceAddr, + aclDataType dataType, aclTensor** tensor) { + auto size = GetShapeSize(shape) * sizeof(T); + // 调用aclrtMalloc申请device侧内存 + auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret); + // 调用aclrtMemcpy将host侧数据拷贝到device侧内存上 + ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. 
ERROR: %d\n", ret); return ret); + + // 计算连续tensor的strides + std::vector strides(shape.size(), 1); + for (int64_t i = shape.size() - 2; i >= 0; i--) { + strides[i] = shape[i + 1] * strides[i + 1]; + } + + // 调用aclCreateTensor接口创建aclTensor + *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND, + shape.data(), shape.size(), *deviceAddr); + return 0; +}*/ + diff --git a/python/jittor/extern/acl/aclnn/aclnn.h b/python/jittor/extern/acl/aclnn/aclnn.h new file mode 100644 index 00000000..0a4eef19 --- /dev/null +++ b/python/jittor/extern/acl/aclnn/aclnn.h @@ -0,0 +1,134 @@ +#include +#include +#include "acl.h" +// unary +#include "aclnnop/aclnn_abs.h" +#include "aclnnop/aclnn_neg.h" +#include "aclnnop/aclnn_exp.h" +#include "aclnnop/aclnn_log.h" +#include "aclnnop/aclnn_sqrt.h" +#include "aclnnop/aclnn_ceil.h" +#include "aclnnop/aclnn_floor.h" +#include "aclnnop/aclnn_round.h" +#include "aclnnop/aclnn_sin.h" +#include "aclnnop/aclnn_cos.h" +#include "aclnnop/aclnn_tan.h" +#include "aclnnop/aclnn_asin.h" +#include "aclnnop/aclnn_acos.h" +#include "aclnnop/aclnn_atan.h" +#include "aclnnop/aclnn_sinh.h" +#include "aclnnop/aclnn_cosh.h" +#include "aclnnop/aclnn_tanh.h" +#include "aclnnop/aclnn_asinh.h" +#include "aclnnop/aclnn_acosh.h" +#include "aclnnop/aclnn_atanh.h" +#include "aclnnop/aclnn_sigmoid.h" +#include "aclnnop/aclnn_erf.h" +#include "aclnnop/aclnn_erfinv.h" +#include "aclnnop/aclnn_logical_not.h" +#include "aclnnop/aclnn_bitwise_not.h" +#include "aclnnop/aclnn_cast.h" +#include "aclnnop/aclnn_nonzero.h" +// binary +#include "aclnnop/aclnn_maximum.h" +#include "aclnnop/aclnn_minimum.h" +#include "aclnnop/aclnn_add.h" +#include "aclnnop/aclnn_sub.h" +#include "aclnnop/aclnn_mul.h" +#include "aclnnop/aclnn_div.h" +#include "aclnnop/aclnn_floor_divide.h" +#include "aclnnop/aclnn_le_tensor.h" +#include "aclnnop/aclnn_lt_tensor.h" +#include "aclnnop/aclnn_ge_tensor.h" +#include "aclnnop/aclnn_gt_tensor.h" +#include "aclnnop/aclnn_eq_tensor.h" +#include "aclnnop/aclnn_ne_tensor.h" +#include "aclnnop/aclnn_logical_and.h" +#include "aclnnop/aclnn_logical_or.h" +#include "aclnnop/aclnn_logical_xor.h" +#include "aclnnop/aclnn_bitwise_and_tensor.h" +#include "aclnnop/aclnn_bitwise_or_tensor.h" +#include "aclnnop/aclnn_bitwise_xor_tensor.h" +#include "aclnnop/aclnn_pow_tensor_tensor.h" +#include "aclnnop/aclnn_expand.h" +#include "aclnnop/aclnn_matmul.h" +#include "aclnnop/aclnn_batch_matmul.h" +#include "aclnnop/aclnn_convolution.h" +#include "aclnnop/aclnn_convolution_backward.h" +#include "aclnnop/aclnn_reduce_sum.h" +#include "aclnnop/aclnn_amax.h" +#include "aclnnop/aclnn_amin.h" +#include "aclnnop/aclnn_mean.h" +#include "aclnnop/aclnn_prod.h" +#include "aclnnop/aclnn_triu.h" +#include "aclnnop/aclnn_s_where.h" +#include "aclnnop/aclnn_random.h" +#include "aclnnop/aclnn_normal.h" +#include "aclnnop/aclnn_permute.h" +#include "aclnnop/aclnn_max_pool2d_with_indices.h" +#include "aclnnop/aclnn_max_pool2d_with_indices_backward.h" +#include "aclnnop/aclnn_avgpool2d.h" +#include "aclnnop/aclnn_avgpool2d_backward.h" +#include "aclnnop/aclnn_flip.h" +#include "aclnnop/aclnn_cat.h" +#include "aclnnop/aclnn_gather.h" +#include "aclnnop/aclnn_cumsum.h" +#include "aclnnop/aclnn_index.h" +#include "aclnnop/aclnn_scatter.h" +#include "aclnnop/aclnn_index.h" +#include "aclnnop/aclnn_strided_slice_assign_v2.h" +#include "aclnnop/aclnn_slice_v2.h" +#include "aclnnop/aclnn_index_put_impl.h" +#include "aclnnop/aclnn_range.h" +#include "aclnnop/aclnn_relu.h" 
+#include "aclnnop/aclnn_dropout.h" +#include "aclnnop/aclnn_dropout_backward.h" +#include "aclnnop/aclnn_leaky_relu.h" +#include "aclnnop/aclnn_leaky_relu_backward.h" +#include "aclnnop/aclnn_uniform.h" +#include "aclnnop/aclnn_silu.h" +#include "aclnnop/aclnn_silu_backward.h" +#include "aclnnop/aclnn_sigmoid.h" +#include "aclnnop/aclnn_sigmoid_backward.h" +#include "aclnnop/aclnn_embedding.h" +#include "aclnnop/aclnn_embedding_dense_backward.h" +#include "aclnnop/aclnn_masked_scatter.h" +#include "aclnnop/aclnn_masked_select.h" +#include "aclnnop/aclnn_split_with_size.h" +#include "aclnnop/aclnn_flash_attention_score.h" +#include "aclnnop/aclnn_flash_attention_score_grad.h" +#include "aclnnop/aclnn_softmax.h" +#include "aclnnop/aclnn_softmax_backward.h" +#include "aclnnop/aclnn_batch_norm.h" +#include "aclnnop/aclnn_batch_norm_backward.h" +#include "aclnnop/aclnn_layer_norm.h" +#include "aclnnop/aclnn_apply_rotary_pos_emb.h" +#include "aclnnop/aclnn_stack.h" +#include "aclnnop/aclnn_nan_to_num.h" + +#define CHECK_RET(cond, return_expr) \ + do \ + { \ + if (!(cond)) \ + { \ + return_expr; \ + } \ + } while (0) + +#define LOG_PRINT(message, ...) \ + do \ + { \ + printf(message, ##__VA_ARGS__); \ + } while (0) + +int64_t GetShapeSize(const std::vector &shape); + +void PrintOutResult(std::vector &shape, void **deviceAddr); + +//int Init(int32_t deviceId); + +/* +template +int CreateAclTensor(const std::vector& hostData, const std::vector& shape, void** deviceAddr, + aclDataType dataType, aclTensor** tensor); +*/ diff --git a/python/jittor/extern/acl/aclops/__init__.py b/python/jittor/extern/acl/aclops/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/jittor/extern/acl/aclops/aclops.h b/python/jittor/extern/acl/aclops/aclops.h new file mode 100644 index 00000000..19a6c5c9 --- /dev/null +++ b/python/jittor/extern/acl/aclops/aclops.h @@ -0,0 +1,33 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/base_op.h b/python/jittor/extern/acl/aclops/base_op.h new file mode 100644 index 00000000..93eef363 --- /dev/null +++ b/python/jittor/extern/acl/aclops/base_op.h @@ -0,0 +1,56 @@ +#pragma once +#include "utils.h" +#include "acl_jittor.h" + +namespace jittor +{ + extern int sync_run; + class BaseOpRunner + { + protected: + vector in_; + vector out_; + + int ret = -1; + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + bool is_group_op = false; + + std::vector> inputShapes; + std::vector> outputShapes; + + std::vector inputTensors; + std::vector outputTensors; + + public: + string name; + string jt_name; + std::unique_ptr op_attr; + bool use_nchw = false; + + BaseOpRunner(const string &name = "") : name(name) {} + virtual ~BaseOpRunner() = default; + + // Common functionality for adding input/output variables + void add(Var *v, bool is_input); + + virtual void setupInputDesc(); + + void cleanupDesc(); + + virtual void setupOutputDesc(); + + virtual void syncRun(); + + void checkRet(aclnnStatus ret); + + // Base run method with common operator lookup logic + void run(); + + protected: + // Virtual method for specific operator execution + virtual void executeOp(std::unordered_map::iterator &it) = 0; + void 
cleanupAttr(); + }; + +} diff --git a/python/jittor/extern/acl/aclops/base_op_acl.cc b/python/jittor/extern/acl/aclops/base_op_acl.cc new file mode 100644 index 00000000..900300c7 --- /dev/null +++ b/python/jittor/extern/acl/aclops/base_op_acl.cc @@ -0,0 +1,152 @@ +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "binary_op_acl.h" +#include "base_op.h" + +namespace jittor +{ + extern int sync_run; + // Common functionality for adding input/output variables + void BaseOpRunner::add(Var *v, bool is_input) + { + if (is_input) + { + in_.push_back(v); + } + else + { + out_.push_back(v); + } + return; + } + + void BaseOpRunner::setupInputDesc() + { + auto input_num = in_.size(); + for (int input_idx = 0; input_idx < input_num; input_idx++) + { + std::vector shape; + for (int j = 0; j < in_[input_idx]->shape.size(); j++) + { + shape.push_back(in_[input_idx]->shape[j]); + } + inputShapes.push_back(shape); + } + + for (int idx = 0; idx < input_num; idx++) + { + inputTensors.push_back(nullptr); + auto ret = CreateAclTensor(inputShapes[idx], in_[idx]->mem_ptr, in_[idx]->size, get_dtype(in_[idx]->dtype()), &inputTensors[idx], use_nchw); + CHECK_RET(ret == ACL_SUCCESS, return); + } + } + + void BaseOpRunner::cleanupDesc() + { + auto input_num = in_.size(); + auto output_num = out_.size(); + for (int idx = 0; idx < input_num; idx++) + { + aclDestroyTensor(inputTensors[idx]); + } + for (int idx = 0; idx < output_num; idx++) + { + aclDestroyTensor(outputTensors[idx]); + } + } + + void BaseOpRunner::setupOutputDesc() + { + auto output_num = out_.size(); + + for (int output_idx = 0; output_idx < output_num; output_idx++) + { + std::vector shape; + for (int j = 0; j < out_[output_idx]->shape.size(); j++) + { + shape.push_back(out_[output_idx]->shape[j]); + } + outputShapes.push_back(shape); + } + + for (int idx = 0; idx < output_num; idx++) + { + outputTensors.push_back(nullptr); + auto ret = CreateAclTensor(outputShapes[idx], out_[idx]->mem_ptr, out_[idx]->size, get_dtype(out_[idx]->dtype()), &outputTensors[idx], use_nchw); + CHECK_RET(ret == ACL_SUCCESS, return); + } + } + + void BaseOpRunner::syncRun() + { + if (sync_run) + { + // ret = aclrtSynchronizeStream(aclstream); + // CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclrtSynchronizeStream failed. ERROR: %d\n", name.c_str(), ret); return); + } + } + + void BaseOpRunner::checkRet(aclnnStatus ret) + { + if (ret != ACL_SUCCESS) + { + auto tmp_err_msg = aclGetRecentErrMsg(); + LOGir << name << ", " << tmp_err_msg; + } + + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxxGetWorkspaceSize failed. 
ERROR: %d\n", name.c_str(), ret); return); + } + + // Base run method with common operator lookup logic + void BaseOpRunner::run() + { + if (is_group_op) + { + auto it = aclOpFuncMap.find(name); + if (it == aclOpFuncMap.end()) + { + LOGir << "aclOpFuncMap Not supported op: " << name; + throw std::runtime_error("Unsupported operation type."); + } + setupInputDesc(); + setupOutputDesc(); + executeOp(it); + cleanupDesc(); + } + else + { + auto it = aclOpFuncMap.find(name); + setupInputDesc(); + setupOutputDesc(); + executeOp(it); + cleanupDesc(); + } + } + +} diff --git a/python/jittor/extern/acl/aclops/binary_op_acl.cc b/python/jittor/extern/acl/aclops/binary_op_acl.cc new file mode 100644 index 00000000..18142491 --- /dev/null +++ b/python/jittor/extern/acl/aclops/binary_op_acl.cc @@ -0,0 +1,124 @@ +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "binary_op_acl.h" + +namespace jittor +{ + BinaryOpRunner::BinaryOpRunner() : BaseOpRunner("binary") + { + use_nchw = false; + is_group_op = true; + } + + void BinaryOpRunner::executeOp(std::unordered_map::iterator &it) + { + aclScalar *alpha = nullptr; + + if (name == string("Add") || name == string("Sub")) + { + if (get_dtype(in_[0]->dtype()) == ACL_FLOAT) + { + float alphaValue = 1.0; + alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype())); + } + else if (get_dtype(in_[0]->dtype()) == ACL_FLOAT16) + { + __fp16 alphaValue = 1.0; + alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype())); + } + else if (get_dtype(in_[0]->dtype()) == ACL_INT64) + { + int64_t alphaValue = 1; + alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype())); + } + else if (get_dtype(in_[0]->dtype()) == ACL_INT32) + { + int alphaValue = 1; + alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype())); + } + else if (get_dtype(in_[0]->dtype()) == ACL_INT8) + { + int8_t alphaValue = 1; + alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype())); + } + else if (get_dtype(in_[0]->dtype()) == ACL_INT16) + { + int16_t alphaValue = 1; + alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype())); + } + else if (get_dtype(in_[0]->dtype()) == ACL_UINT8) + { + uint8_t alphaValue = 1; + alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype())); + } + else if (get_dtype(in_[0]->dtype()) == ACL_UINT16) + { + uint16_t alphaValue = 1; + alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype())); + } + else if (get_dtype(in_[0]->dtype()) == ACL_UINT32) + { + uint32_t alphaValue = 1; + alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype())); + } + else if (get_dtype(in_[0]->dtype()) == ACL_BOOL) + { + bool alphaValue = true; + alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype())); + } + else + { + LOGf << "Not supported dtype: " << in_[0]->dtype(); + } + + CHECK_RET(alpha != nullptr, return); + ret = it->second.getWorkspaceSizeFuncAdd(inputTensors[0], inputTensors[1], alpha, outputTensors[0], &workspaceSize, &executor); + } + 
else + + { + ret = it->second.getWorkspaceSizeFuncBinary(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor); + } + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = it->second.executeFunc(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + aclDestroyScalar(alpha); + return; + } +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/binary_op_acl.h b/python/jittor/extern/acl/aclops/binary_op_acl.h new file mode 100644 index 00000000..82dc4762 --- /dev/null +++ b/python/jittor/extern/acl/aclops/binary_op_acl.h @@ -0,0 +1,14 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + struct BinaryOpRunner : public BaseOpRunner + { + BinaryOpRunner(); + + protected: + void executeOp(std::unordered_map::iterator &it) override; + }; +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/bmm_op.py b/python/jittor/extern/acl/aclops/bmm_op.py new file mode 100644 index 00000000..78fee49c --- /dev/null +++ b/python/jittor/extern/acl/aclops/bmm_op.py @@ -0,0 +1,128 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def acl_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None, + extra_data: dict = {}): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + // aclop + BatchMatMulOpRunner op; + {input_code} + op.add(out0, false); + {attr_code} + op.run();""", + data=extra_data) + + +class BmmACL(jt.Function): + + def __init__(self, trans_x2=False): + super(BmmACL, self).__init__() + self.trans_x2 = trans_x2 + + def execute(self, x1, x2): + self.input = [x1, x2] + result = acl_cmd("BatchMatMul", [x1, x2], + output_dtypes=[x1.dtype], + output_shapes=[ + x1.shape[:-1] + x2.shape[-2:-1] if self.trans_x2 + else x1.shape[:-1] + x2.shape[-1:] + ], + attr_code="op.jt_name=\"bmm_trans_1\";" + if self.trans_x2 else "op.jt_name=\"bmm\";")[0] + + return result + + def grad(self, grad_output): + x1, x2 = self.input + if len(x1) != len(x2): + reshape_grad_x2 = True + else: + reshape_grad_x2 = False + grad_x1 = acl_cmd( + "BatchMatMul", [grad_output, x2], + output_dtypes=[x1.dtype], + output_shapes=[ + grad_output.shape[:-1] + x2.shape[-2:-1] if not self.trans_x2 + else grad_output.shape[:-1] + x1.shape[-1:] + ], + attr_code="op.jt_name=\"bmm_trans_1\";" + if not self.trans_x2 else "op.jt_name=\"bmm\";")[0] + if self.trans_x2: + if reshape_grad_x2: + output_shape = grad_output.shape[1:-2] + grad_output.shape[ + -1:] + x1.shape[-1:] + grad_x2 = acl_cmd("BatchMatMul", [ + 
grad_output.reshape(-1, grad_output.shape[-1]), + x1.reshape(-1, x1.shape[-1]) + ], + output_dtypes=[x2.dtype], + output_shapes=[output_shape], + attr_code="op.jt_name=\"bmm_trans_0\";")[0] + else: + output_shape = grad_output.shape[:-2] + grad_output.shape[ + -1:] + x1.shape[-1:] + grad_x2 = acl_cmd("BatchMatMul", [grad_output, x1], + output_dtypes=[x2.dtype], + output_shapes=[output_shape], + attr_code="op.jt_name=\"bmm_trans_0\";")[0] + else: + if reshape_grad_x2: + output_shape = x1.shape[1:-2] + x1.shape[ + -1:] + grad_output.shape[-1:] + grad_x2 = acl_cmd("BatchMatMul", [ + x1.reshape(-1, x1.shape[-1]), + grad_output.reshape(-1, grad_output.shape[-1]) + ], + output_dtypes=[x2.dtype], + output_shapes=[output_shape], + attr_code="op.jt_name=\"bmm_trans_0\";")[0] + else: + output_shape = x1.shape[:-2] + x1.shape[ + -1:] + grad_output.shape[-1:] + grad_x2 = acl_cmd("BatchMatMul", [x1, grad_output], + output_dtypes=[x2.dtype], + output_shapes=[output_shape], + attr_code="op.jt_name=\"bmm_trans_0\";")[0] + if len(grad_x1.shape) > len(x1.shape): + grad_x1 = grad_x1.sum(0) + if len(grad_x2.shape) > len(x2.shape): + grad_x2 = grad_x2.sum(0) + return grad_x1, grad_x2 diff --git a/python/jittor/extern/acl/aclops/bmm_op_acl.cc b/python/jittor/extern/acl/aclops/bmm_op_acl.cc new file mode 100644 index 00000000..b75f9177 --- /dev/null +++ b/python/jittor/extern/acl/aclops/bmm_op_acl.cc @@ -0,0 +1,77 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "bmm_op_acl.h" + +namespace jittor +{ + BatchMatMulOpRunner::BatchMatMulOpRunner() : BaseOpRunner("BatchMatMulMatMul") + { + } + void BatchMatMulOpRunner::setupInputDesc() + { + auto input_num = in_.size(); + for (int input_idx = 0; input_idx < input_num; input_idx++) + { + std::vector shape; + for (int j = 0; j < in_[input_idx]->shape.size(); j++) + { + shape.push_back(in_[input_idx]->shape[j]); + } + inputShapes.push_back(shape); + } + for (int idx = 0; idx < input_num; idx++) + { + inputTensors.push_back(nullptr); + if ((jt_name == "bmm_trans_1" && idx == 1) || (jt_name == "bmm_trans_0" && idx == 0)) + { + auto ret = CreateFakeTransAclTensor(inputShapes[idx], in_[idx]->mem_ptr, in_[idx]->size, get_dtype(in_[idx]->dtype()), &inputTensors[idx], use_nchw); + CHECK_RET(ret == ACL_SUCCESS, return); + } + else + { + auto ret = CreateAclTensor(inputShapes[idx], in_[idx]->mem_ptr, in_[idx]->size, get_dtype(in_[idx]->dtype()), &inputTensors[idx], use_nchw); + CHECK_RET(ret == ACL_SUCCESS, return); + } + } + } + void BatchMatMulOpRunner::executeOp(std::unordered_map::iterator &it) + { + + ret = aclnnBatchMatMulGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnBatchMatmulGetWorkspaceSize failed. 
ERROR: %d\n", name.c_str(), ret); return); + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + ret = aclnnBatchMatMul(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnbatchMatmul failed. ERROR: %d\n", name.c_str(), ret); return); + syncRun(); + } +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/bmm_op_acl.h b/python/jittor/extern/acl/aclops/bmm_op_acl.h new file mode 100644 index 00000000..283bf5fc --- /dev/null +++ b/python/jittor/extern/acl/aclops/bmm_op_acl.h @@ -0,0 +1,17 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class BatchMatMulOpRunner : public BaseOpRunner + { + + protected: + void setupInputDesc() override; + void executeOp(std::unordered_map::iterator &it) override; + + public: + BatchMatMulOpRunner(); + }; +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/concat_op.py b/python/jittor/extern/acl/aclops/concat_op.py new file mode 100644 index 00000000..7a25f79c --- /dev/null +++ b/python/jittor/extern/acl/aclops/concat_op.py @@ -0,0 +1,186 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def concat_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + + +class ConcatACL(jt.Function): + + def __init__(self): + super(ConcatACL, self).__init__() + + def __call__(self, *args): + assert isinstance(args[0], (list, tuple)) + assert isinstance(args[1], int) + if jt.flags.no_grad: + return self.execute(*args) + backup = args + args = list(args) + taped_inputs = [] + taped_outputs = [] + input_mask = [-1] * (len(args[0]) + 1) + newargs = [list(), args[1]] + for i, v in enumerate(args[0]): + if isinstance(v, jt.Var): + if v.is_stop_grad(): + # -2 in input_mask represents it is stop_grad + input_mask[i] = -2 + newargs[0].append(v) + continue + v = v.tape() + newargs[0].append(v) + input_mask[i] = len(taped_inputs) + taped_inputs.append(v) + + ori_res = self.execute(*newargs) + if not isinstance(ori_res, Sequence): + res = [ori_res] + else: + res = list(ori_res) + output_mask = [-1] * len(res) + for i, v in enumerate(res): + if isinstance(v, jt.Var): + v = v.tape() + output_mask[i] = len(taped_outputs) + res[i] = v + taped_outputs.append(v) + self.input_mask = input_mask + self.output_mask = output_mask + # tape output and input together so + # backward treat them as one operator + 
jt.tape_together(taped_inputs, taped_outputs, self._grad) + if isinstance(ori_res, Sequence): + return res + else: + return res[0] + + def execute(self, input_tensors, dim=0): + for _ in input_tensors: + if not (-_.ndim <= dim < _.ndim): + print(_.shape, dim) + raise ValueError("dim out of range") + + if dim < 0: + dim += input_tensors[0].ndim + + self.input = input_tensors + self.dim = dim + for i in range(len(input_tensors)): + if input_tensors[i].dtype != input_tensors[0].dtype: + raise ValueError("All input tensors must have the same dtype") + if input_tensors[i].shape[:dim] != input_tensors[ + 0].shape[:dim] or input_tensors[i].shape[ + dim + 1:] != input_tensors[0].shape[dim + 1:]: + raise ValueError("All input tensors must have the same shape") + attr_code = f""" + op.jt_name = "concat"; + ConcatAttr *attr = new ConcatAttr(); + attr->tensorNum = {len(input_tensors)}; + attr->dim = {dim}; + op.op_attr.reset(attr); + """ + result = concat_cmd( + "Concat", + input_tensors, + output_dtypes=[input_tensors[0].dtype], + output_shapes=[ + jt.empty(self.calculate_output_shape(input_tensors, dim)).shape + ], + attr_code=attr_code)[0] + return result + + def _grad(self, *args): + new_args = ((args[i] if i >= 0 else None) for i in self.output_mask) + ret = self.grad(*new_args) + new_ret = [] + for i, r in enumerate(ret): + j = self.input_mask[i] + if j < 0: + # -2 in input_mask represents it is stop_grad + assert r is None or j==-2, f"{type(self)}'s {i}-th returned grad should be None, "\ + "because the input value is not jittor variable." + else: + new_ret.append(r) + return new_ret + + def grad(self, grad_output): + grad_inputs = self.split_grad(grad_output, self.input, self.dim) + return grad_inputs + + def calculate_output_shape(self, input_tensors, axis): + shape = list(input_tensors[0].shape) + for tensor in input_tensors[1:]: + shape[axis] += tensor.shape[axis] + return tuple(shape) + + def split_grad(self, grad_output, input_tensors, axis): + offset = [] + shapeVec = [] + dtypeVec = [] + for tensor in input_tensors: + offset.append(tensor.shape[axis]) + dtypeVec.append(tensor.dtype) + shapeVec.append(tensor.shape) + + attr_code = f""" + op.jt_name = "splitwithsize"; + auto *attr = new SplitWithSizeAttr(); + attr->splitSize = {{ {", ".join(map(str, offset))} }}; + attr->dim = {axis}; + op.op_attr.reset(attr); + """ + + result = concat_cmd("SplitWithSize", [grad_output], + output_dtypes=dtypeVec, + output_shapes=shapeVec, + attr_code=attr_code) + return result diff --git a/python/jittor/extern/acl/aclops/concat_op_acl.cc b/python/jittor/extern/acl/aclops/concat_op_acl.cc new file mode 100644 index 00000000..8da7caf6 --- /dev/null +++ b/python/jittor/extern/acl/aclops/concat_op_acl.cc @@ -0,0 +1,89 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "concat_op_acl.h" + +namespace jittor +{ + ConcatOpRunner::ConcatOpRunner() : BaseOpRunner("Concat") + { + } + + void 
ConcatOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto input_num = in_.size(); + std::vector concatTensorList = {}; + for (int i = 0; i < input_num; i++) + { + concatTensorList.push_back(inputTensors[i]); + } + auto concatTensorListInput = aclCreateTensorList(&concatTensorList[0], input_num); + auto attr = dynamic_cast(op_attr.get()); + ret = aclnnCatGetWorkspaceSize(concatTensorListInput, attr->dim, outputTensors[0], &workspaceSize, &executor); + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnCat(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnCat failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + return; + } + + SplitWithSizeOpRunner::SplitWithSizeOpRunner() : BaseOpRunner("SplitWithSize") + { + } + + void SplitWithSizeOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto output_num = out_.size(); + auto attr = dynamic_cast(op_attr.get()); + auto splitSize = aclCreateIntArray(attr->splitSize.data(), attr->splitSize.size()); + auto tensorList = aclCreateTensorList(&outputTensors[0], output_num); + ret = aclnnSplitWithSizeGetWorkspaceSize(inputTensors[0], splitSize, attr->dim, tensorList, &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnSplitWithSize(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSplitWithSize failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/concat_op_acl.h b/python/jittor/extern/acl/aclops/concat_op_acl.h new file mode 100644 index 00000000..a051e343 --- /dev/null +++ b/python/jittor/extern/acl/aclops/concat_op_acl.h @@ -0,0 +1,26 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class ConcatOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + ConcatOpRunner(); + }; + + class SplitWithSizeOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + SplitWithSizeOpRunner(); + }; +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/conv_op.py b/python/jittor/extern/acl/aclops/conv_op.py new file mode 100644 index 00000000..a0487e21 --- /dev/null +++ b/python/jittor/extern/acl/aclops/conv_op.py @@ -0,0 +1,160 @@ +import os +import jittor_utils +from jittor_utils import env_or_try_find +import ctypes +import glob +import jittor as jt +import jittor.compiler as compiler +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def _ntuple(n): + + def parse(x): + if isinstance(x, Iterable): + return x + return tuple([x] * n) + + return parse + + +_pair = _ntuple(2) + + +def conv_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], 
output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + +class ConvACL(jt.Function): + + def execute(self, + x, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1): + self.input = x + self.weight = weight + self.bias = bias + padding = _pair(padding) + stride = _pair(stride) + dilation = _pair(dilation) + out_channels = weight.shape[0] + if groups <= 0: + raise ValueError("groups must be a positive integer") + self.padding = padding + self.stride = stride + self.dilation = dilation + self.groups = groups + attr_code = f""" + op.jt_name = "conv2d"; + ConvAttr *attr = new ConvAttr(); + attr->convStrides = {{ {stride[0]}, {stride[1]} }}; + attr->convPads = {{ {padding[0]}, {padding[1]} }}; + attr->convDilations = {{ {dilation[0]}, {dilation[1]} }}; + attr->group = {groups}; + attr->convOutPads = {{1,1}}; + op.op_attr.reset(attr); + """ + input_height, input_width = x.shape[-2:] + kernel_height, kernel_width = weight.shape[-2:] + + output_height = (input_height + 2 * padding[0] - dilation[0] * + (kernel_height - 1) - 1) // stride[0] + 1 + output_width = (input_width + 2 * padding[1] - dilation[1] * + (kernel_width - 1) - 1) // stride[1] + 1 + + output_shape = (x.shape[0], out_channels, output_height, output_width) + + inputs = [x, weight] + if bias is not None: + inputs.append(bias) + result = conv_cmd( + "Conv2d", + inputs, + output_dtypes=[x.dtype], + output_shapes=[output_shape], + attr_code=attr_code, + )[0] + return result + + def grad(self, grad_output): + x = self.input + weight = self.weight + bias = self.bias + inputs = [grad_output, x, weight] + if bias is not None: + inputs.append(bias) + output_shapes = [x.shape, weight.shape] + output_dtypes = [x.dtype, weight.dtype] + if bias is not None: + output_shapes.append(bias.shape) + output_dtypes.append(bias.dtype) + else: + output_shapes.append([weight.shape[0]]) + output_dtypes.append(x.dtype) + padding = self.padding + stride = self.stride + dilation = self.dilation + groups = self.groups + attr_code = f""" + op.jt_name = "conv2dbackward"; + ConvAttr *attr = new ConvAttr(); + attr->convStrides = {{ {stride[0]}, {stride[1]} }}; + attr->convPads = {{ {padding[0]}, {padding[1]} }}; + attr->convDilations = {{ {dilation[0]}, {dilation[1]} }}; + attr->group = {groups}; + attr->convOutPads = {{ 1,1}}; + op.op_attr.reset(attr); + """ + results = conv_cmd("Conv2dBackward", + inputs, + output_dtypes=output_dtypes, + output_shapes=output_shapes, + attr_code=attr_code) + if self.bias is None: + return results[0], results[1] + + return results diff --git a/python/jittor/extern/acl/aclops/conv_op_acl.cc b/python/jittor/extern/acl/aclops/conv_op_acl.cc new file mode 100644 index 00000000..ed628696 --- /dev/null +++ b/python/jittor/extern/acl/aclops/conv_op_acl.cc @@ -0,0 +1,152 @@ +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include 
"ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "conv_op_acl.h" + +namespace jittor +{ + Conv2dOpRunner::Conv2dOpRunner() : BaseOpRunner("Conv2d") + { + use_nchw = true; + } + + void Conv2dOpRunner::executeOp(std::unordered_map::iterator &it) + { + // for conv + aclIntArray *strides = nullptr; + aclIntArray *pads = nullptr; + aclIntArray *outPads = nullptr; + aclIntArray *dilations = nullptr; + auto attr = dynamic_cast(op_attr.get()); + strides = aclCreateIntArray(attr->convStrides.data(), 2); + pads = aclCreateIntArray(attr->convPads.data(), 2); + outPads = aclCreateIntArray(attr->convOutPads.data(), 2); + dilations = aclCreateIntArray(attr->convDilations.data(), 2); + + aclTensor *bias = nullptr; + + auto input_num = in_.size(); + if (input_num == 3) + bias = inputTensors[2]; + + ret = aclnnConvolutionGetWorkspaceSize(inputTensors[0], inputTensors[1], bias, strides, pads, dilations, false, outPads, attr->group, outputTensors[0], 0, &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnConvolution(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnConvolution failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + aclDestroyIntArray(strides); + aclDestroyIntArray(pads); + aclDestroyIntArray(outPads); + aclDestroyIntArray(dilations); + return; + } + + + Conv2dBackwardOpRunner::Conv2dBackwardOpRunner() : BaseOpRunner("Conv2dBackward") + { + use_nchw = true; + } + + void Conv2dBackwardOpRunner::setupOutputDesc() + { + auto output_num = out_.size(); + + for (int output_idx = 0; output_idx < output_num; output_idx++) + { + std::vector shape; + for (int j = 0; j < out_[output_idx]->shape.size(); j++) + { + shape.push_back(out_[output_idx]->shape[j]); + } + outputShapes.push_back(shape); + } + + for (int idx = 0; idx < 2; idx++) + { + outputTensors.push_back(nullptr); + auto ret = CreateAclTensor(outputShapes[idx], out_[idx]->mem_ptr, out_[idx]->size, get_dtype(out_[idx]->dtype()), &outputTensors[idx], use_nchw); + CHECK_RET(ret == ACL_SUCCESS, return); + } + // biasgrad nd format + { + outputTensors.push_back(nullptr); + auto ret = CreateAclTensor(outputShapes[2], out_[2]->mem_ptr, out_[2]->size, get_dtype(out_[2]->dtype()), &outputTensors[2], false); + CHECK_RET(ret == ACL_SUCCESS, return); + } + } + + void Conv2dBackwardOpRunner::executeOp(std::unordered_map::iterator &it) + { + // for conv + aclIntArray *strides = nullptr; + aclIntArray *pads = nullptr; + aclIntArray *outPads = nullptr; + aclIntArray *dilations = nullptr; + auto attr = dynamic_cast(op_attr.get()); + strides = aclCreateIntArray(attr->convStrides.data(), 2); + pads = aclCreateIntArray(attr->convPads.data(), 2); + outPads = aclCreateIntArray(attr->convOutPads.data(), 2); + dilations = aclCreateIntArray(attr->convDilations.data(), 2); + bool outputMask[3] = {true, true, true}; + auto input_num = in_.size(); + if (input_num == 3) + { + outputMask[2] = false; + } + aclBoolArray *outMask = aclCreateBoolArray(outputMask, 3); + auto biasSizes = aclCreateIntArray(&outputShapes[2][0], outputShapes[2].size()); + ret = aclnnConvolutionBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], biasSizes, strides, pads, dilations, false, outPads, attr->group, outMask, 0, 
outputTensors[0], outputTensors[1], outputTensors[2], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnConvolutionBackward(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnConvolutionBackward failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + aclDestroyIntArray(strides); + aclDestroyIntArray(pads); + aclDestroyIntArray(outPads); + aclDestroyIntArray(dilations); + return; + } +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/conv_op_acl.h b/python/jittor/extern/acl/aclops/conv_op_acl.h new file mode 100644 index 00000000..2054f3a2 --- /dev/null +++ b/python/jittor/extern/acl/aclops/conv_op_acl.h @@ -0,0 +1,27 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class Conv2dOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + Conv2dOpRunner(); + }; + + class Conv2dBackwardOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + void setupOutputDesc() override; + + public: + Conv2dBackwardOpRunner(); + }; +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/cumsum_op.py b/python/jittor/extern/acl/aclops/cumsum_op.py new file mode 100644 index 00000000..28ce48b8 --- /dev/null +++ b/python/jittor/extern/acl/aclops/cumsum_op.py @@ -0,0 +1,101 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def cumsum_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + + +class CumsumACL(jt.Function): + + def __init__(self): + super(CumsumACL, self).__init__() + + def execute(self, input, dim=-1): + self.dim = dim + attr_code = f""" + op.jt_name = "cumsum"; + GatherAttr *attr = new GatherAttr(); + attr->dim = {dim}; + op.op_attr.reset(attr); + """ + result = cumsum_cmd("Cumsum", [input], + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr_code=attr_code)[0] + return result + + def grad(self, grad_output): + cumsum_attr_code = f""" + op.jt_name = "cumsum"; + GatherAttr *attr = new GatherAttr(); + attr->dim = {self.dim}; + op.op_attr.reset(attr); + """ + flip_attr_code = f""" + op.jt_name = "flip"; + ReduceAttr *attr = new ReduceAttr(); + attr->axes = {{{self.dim}}}; + attr->prod_dim = {{{1}}}; + 
op.op_attr.reset(attr); + """ + flipped_grad_output = cumsum_cmd("Flip", [grad_output], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr_code=flip_attr_code)[0] + cumulative_grad = cumsum_cmd("Cumsum", [flipped_grad_output], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr_code=cumsum_attr_code)[0] + grad_input = cumsum_cmd("Flip", [cumulative_grad], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr_code=flip_attr_code)[0] + return grad_input diff --git a/python/jittor/extern/acl/aclops/cumsum_op_acl.cc b/python/jittor/extern/acl/aclops/cumsum_op_acl.cc new file mode 100644 index 00000000..4d11a0c2 --- /dev/null +++ b/python/jittor/extern/acl/aclops/cumsum_op_acl.cc @@ -0,0 +1,57 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "cumsum_op_acl.h" + +namespace jittor +{ + CumsumOpRunner::CumsumOpRunner() : BaseOpRunner("Cumsum") + { + } + + void CumsumOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + ret = aclnnCumsumGetWorkspaceSize(inputTensors[0], attr->dim, get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnCumsum(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnCumsum failed. 
ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/cumsum_op_acl.h b/python/jittor/extern/acl/aclops/cumsum_op_acl.h new file mode 100644 index 00000000..1b9888f1 --- /dev/null +++ b/python/jittor/extern/acl/aclops/cumsum_op_acl.h @@ -0,0 +1,17 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class CumsumOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + CumsumOpRunner(); + }; + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/dropout_op.py b/python/jittor/extern/acl/aclops/dropout_op.py new file mode 100644 index 00000000..c7f3327a --- /dev/null +++ b/python/jittor/extern/acl/aclops/dropout_op.py @@ -0,0 +1,94 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def dropout_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + + +class DropoutACL(jt.Function): + + def __init__(self): + super(DropoutACL, self).__init__() + + def execute(self, x, p=0.5, is_train=False): + self.input = x + num_elements = x.numel() + aligned_elements = (num_elements + 127) // 128 * 128 + mask_shape = (aligned_elements // 8, ) + attr_code = f""" + op.jt_name = "dropout"; + DropoutAttr *attr = new DropoutAttr(); + attr->p = {p}; + attr->train = {"true" if is_train else "false"}; + attr->seed = 0; + attr->offset = 0; + op.op_attr.reset(attr); + """ + result = dropout_cmd("Dropout", [x], + output_dtypes=[x.dtype, "uint8"], + output_shapes=[x.shape, mask_shape], + attr_code=attr_code) + self.maskout = result[1] + return result[0] + + def grad(self, grad_output): + attr_code = f""" + op.jt_name = "dropoutbackward"; + DropoutAttr *attr = new DropoutAttr(); + attr->scale = 1.0; + op.op_attr.reset(attr); + """ + grad_input = dropout_cmd("DropoutBackward", + [grad_output, self.maskout], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr_code=attr_code)[0] + return grad_input diff --git a/python/jittor/extern/acl/aclops/dropout_op_acl.cc b/python/jittor/extern/acl/aclops/dropout_op_acl.cc new file mode 100644 index 00000000..9d413b87 --- /dev/null +++ b/python/jittor/extern/acl/aclops/dropout_op_acl.cc @@ -0,0 +1,82 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" 
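The backward pass of CumsumACL above flips the upstream gradient along dim, runs cumsum, and flips back. This is the standard gradient of a cumulative sum: output position i depends on every input position j <= i, so input position j collects the upstream gradients from all positions i >= j, i.e. a reverse cumulative sum. A minimal NumPy check of that identity (illustrative only, assuming NumPy is available; not part of the patch):

import numpy as np

# Jacobian of a 1-D cumsum is lower-triangular ones: J[i, j] = 1 for j <= i.
w = np.random.rand(5)                      # upstream gradient d(loss)/d(cumsum(x))
J = np.tril(np.ones((5, 5)))
ref = J.T @ w                              # d(loss)/dx = J^T w, a reverse cumsum of w
via_flip = np.flip(np.cumsum(np.flip(w)))  # the flip -> cumsum -> flip route used above
assert np.allclose(ref, via_flip)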
+#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "dropout_op_acl.h" + +namespace jittor +{ + DropoutOpRunner::DropoutOpRunner() : BaseOpRunner("Dropout") + { + } + + void DropoutOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + ret = aclnnDropoutGetWorkspaceSize(inputTensors[0], attr->p, attr->train, attr->seed, attr->offset, outputTensors[0], outputTensors[1], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnDropout(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnDropout failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + return; + } + + DropoutBackwardOpRunner::DropoutBackwardOpRunner() : BaseOpRunner("DropoutBackward") + { + } + + void DropoutBackwardOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + ret = aclnnDropoutBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], attr->scale, outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnDropoutBackward(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnDropoutBackward failed. 
ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/dropout_op_acl.h b/python/jittor/extern/acl/aclops/dropout_op_acl.h new file mode 100644 index 00000000..3380b0ec --- /dev/null +++ b/python/jittor/extern/acl/aclops/dropout_op_acl.h @@ -0,0 +1,27 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class DropoutOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + DropoutOpRunner(); + }; + + class DropoutBackwardOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + DropoutBackwardOpRunner(); + }; + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/embedding_op.py b/python/jittor/extern/acl/aclops/embedding_op.py new file mode 100644 index 00000000..9f7156d3 --- /dev/null +++ b/python/jittor/extern/acl/aclops/embedding_op.py @@ -0,0 +1,91 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + +def embedding_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + +class EmbeddingACL(jt.Function): + + def __init__(self): + super(EmbeddingACL, self).__init__() + + def execute( + self, + indices, + weight, + ): + inputs = [weight, indices] + self.indices = indices + self.weight_shape = weight.shape + output_shape = list(indices.shape) + list(weight.shape[1:]) + outputs = [jt.empty(output_shape, weight.dtype)] + attr_code = f""" + op.jt_name = "embedding"; + """ + result = embedding_cmd("Embedding", + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] + return result + + def grad(self, grad_output): + inputs = [grad_output, self.indices] + outputs = [jt.empty(self.weight_shape, grad_output.dtype)] + attr_code = f""" + op.jt_name = "embeddingbackward"; + EmbeddingAttr *attr = new EmbeddingAttr(); + attr->numEmbeddings = {self.weight_shape[0]}; + op.op_attr.reset(attr); + """ + grad_weight = embedding_cmd("EmbeddingBackward", + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] + return None, grad_weight diff --git a/python/jittor/extern/acl/aclops/embedding_op_acl.cc b/python/jittor/extern/acl/aclops/embedding_op_acl.cc new file mode 100644 index 00000000..8cb75521 --- /dev/null +++ b/python/jittor/extern/acl/aclops/embedding_op_acl.cc @@ -0,0 +1,82 @@ +#pragma once 
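In DropoutACL.execute above, the second output returned by the Dropout runner is a uint8 buffer whose length is the element count padded up to a multiple of 128 and then divided by 8, which is consistent with a bit-packed keep-mask (one bit per element, padded to 16-byte granularity); the exact packing is an assumption inferred from that arithmetic. A small helper that reproduces the size computation (dropout_mask_bytes is illustrative and not part of the patch):

def dropout_mask_bytes(numel: int) -> int:
    # Same arithmetic as DropoutACL.execute: round the element count up to a
    # multiple of 128, then one bit per element -> divide by 8 for bytes.
    aligned = (numel + 127) // 128 * 128
    return aligned // 8

# A (3, 50) input has 150 elements -> padded to 256 -> a 32-byte mask buffer.
assert dropout_mask_bytes(150) == 32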
+#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "embedding_op_acl.h" + +namespace jittor +{ + EmbeddingOpRunner::EmbeddingOpRunner() : BaseOpRunner("Embedding") + { + } + + void EmbeddingOpRunner::executeOp(std::unordered_map::iterator &it) + { + ret = aclnnEmbeddingGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnEmbedding(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnEmbedding failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + return; + } + + EmbeddingBackwardOpRunner::EmbeddingBackwardOpRunner() : BaseOpRunner("EmbeddingBackward") + { + } + + void EmbeddingBackwardOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + auto numEmbeddings = attr->numEmbeddings; + ret = aclnnEmbeddingDenseBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], numEmbeddings, 0, false, outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnEmbeddingDenseBackward(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnEmbeddingDenseBackward failed. 
ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/embedding_op_acl.h b/python/jittor/extern/acl/aclops/embedding_op_acl.h new file mode 100644 index 00000000..37e9f69e --- /dev/null +++ b/python/jittor/extern/acl/aclops/embedding_op_acl.h @@ -0,0 +1,25 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class EmbeddingOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + public: + EmbeddingOpRunner(); + }; + + class EmbeddingBackwardOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + public: + EmbeddingBackwardOpRunner(); + }; + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/expand_op_acl.cc b/python/jittor/extern/acl/aclops/expand_op_acl.cc new file mode 100644 index 00000000..0329b5f1 --- /dev/null +++ b/python/jittor/extern/acl/aclops/expand_op_acl.cc @@ -0,0 +1,58 @@ +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "expand_op_acl.h" + +namespace jittor +{ + ExpandOpRunner::ExpandOpRunner() : BaseOpRunner("ternary") + { + use_nchw = false; + } + + void ExpandOpRunner::executeOp(std::unordered_map::iterator &it) + { + aclIntArray *size = nullptr; + size = aclCreateIntArray(&outputShapes[0][0], outputShapes[0].size()); + ret = aclnnExpandGetWorkspaceSize(inputTensors[0], size, outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnExpand(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnExpand failed. 
ERROR: %d\n", name.c_str(), ret); return); + + aclDestroyIntArray(size); + + return; + } +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/expand_op_acl.h b/python/jittor/extern/acl/aclops/expand_op_acl.h new file mode 100644 index 00000000..1026c481 --- /dev/null +++ b/python/jittor/extern/acl/aclops/expand_op_acl.h @@ -0,0 +1,14 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + struct ExpandOpRunner : public BaseOpRunner + { + ExpandOpRunner(); + + protected: + void executeOp(std::unordered_map::iterator &it) override; + }; +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/flashattention_op.py b/python/jittor/extern/acl/aclops/flashattention_op.py new file mode 100644 index 00000000..ce220708 --- /dev/null +++ b/python/jittor/extern/acl/aclops/flashattention_op.py @@ -0,0 +1,209 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def flashattention_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + + +class FlashAttentionACL(jt.Function): + + def __init__(self, + headnum, + layout="BNSD", + prefix=None, + qstart=None, + kvstart=None, + scale=1.0, + prob=1.0, + pretokens=2147483647, + nexttokens=2147483647, + innerprecise=0, + sparsemode=0, + psetype=1): + self.headnum = headnum + self.layout = layout + self.scale = scale + self.prob = prob + self.pretokens = pretokens + self.nexttokens = nexttokens + self.innerprecise = innerprecise + self.sparsemode = sparsemode + self.psetype = psetype + self.prefix = prefix + self.qstart = qstart + self.kvstart = kvstart + + def execute( + self, + q, + k, + v, + realshift=None, + dropMask=None, + paddingMask=None, + attenMask=None, + ): + if self.layout == 'BSH': + B, SQ, H = q.shape + SKV = k.shape[1] + N = self.headnum + D = H / N + elif self.layout == 'SBH': + SQ, B, H = q.shape + SKV = k.shape[0] + N = self.headnum + D = H / N + elif self.layout == 'BSND': + B, SQ, N, D = q.shape + SKV = k.shape[1] + elif self.layout == 'BNSD': + B, N, SQ, D = q.shape + SKV = k.shape[2] + else: + raise ValueError(f"got invalid input layout {self.layout}") + + output_shape = (B, N, SQ, 8) + + self.q = q + self.k = k + self.v = v + + self.prefix = self.prefix if self.prefix else [0 for _ in range(B)] + self.qstart = self.qstart if self.qstart else [0 for _ in range(B)] + self.kvstart = self.kvstart if self.kvstart else [0 for _ in range(B)] + + 
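FlashAttentionACL.execute derives the batch, head, and sequence sizes from the query/key shapes according to the layout string, and allocates the softmax max/sum side outputs with shape (B, N, SQ, 8) in every layout. A compact sketch of that shape bookkeeping (illustrative only; it uses integer division for the head dimension, whereas the code above uses /):

def fa_shapes(q_shape, k_shape, headnum, layout="BNSD"):
    # Mirrors the layout handling in FlashAttentionACL.execute.
    if layout == "BSH":
        B, SQ, H = q_shape
        SKV, N, D = k_shape[1], headnum, H // headnum
    elif layout == "SBH":
        SQ, B, H = q_shape
        SKV, N, D = k_shape[0], headnum, H // headnum
    elif layout == "BSND":
        B, SQ, N, D = q_shape
        SKV = k_shape[1]
    elif layout == "BNSD":
        B, N, SQ, D = q_shape
        SKV = k_shape[2]
    else:
        raise ValueError(f"got invalid input layout {layout}")
    return B, N, SQ, SKV, D, (B, N, SQ, 8)  # last item: max/sum output shape

assert fa_shapes((2, 8, 128, 64), (2, 8, 256, 64), 8)[:4] == (2, 8, 128, 256)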
self.hasRealshift = (not realshift == None) + self.hasDropmask = (not dropMask == None) + self.hasPaddingmask = (not paddingMask == None) + self.hasAttenmask = (not attenMask == None) + + # 待定,目前设为nullptr + self.realshift = realshift if realshift else jt.zeros(B, N, SQ, SKV) + self.dropMask = dropMask if dropMask else jt.ones(B, N, SQ, SKV) + self.paddingMask = paddingMask if paddingMask else jt.zeros( + B, N, SQ, SKV) + self.attenMask = attenMask if attenMask else jt.zeros(SQ, SKV) + + attr_code = f""" + op.jt_name = "flashattention"; + FlashAttentionAttr *attr = new FlashAttentionAttr(); + attr->scale = {self.scale}; + attr->keepProb = {self.prob}; + attr->preToken = {self.pretokens}; + attr->nextToken = {self.nexttokens}; + attr->headNum = {self.headnum}; + attr->inputLayout = "{self.layout}"; + attr->innerPrecise = {self.innerprecise}; + attr->sparseMode = {self.sparsemode}; + attr->psetype = {self.psetype}; + attr->prefix = {{ {", ".join(map(str, self.prefix))} }}; + attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }}; + attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }}; + attr->hasRealshift = {"true" if self.hasRealshift else "false"}; + attr->hasDropmask = {"true" if self.hasDropmask else "false"}; + attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"}; + attr->hasAttentmask = {"true" if self.hasAttenmask else "false"}; + op.op_attr.reset(attr); + """ + + inputs = [ + q, k, v, self.realshift, self.dropMask, self.paddingMask, + self.attenMask + ] + + result = flashattention_cmd( + "FlashAttention", + inputs, + output_dtypes=["float", "float", q.dtype], + output_shapes=[output_shape, output_shape, q.shape], + attr_code=attr_code) + + self.maxout = result[0] + self.sumout = result[1] + self.attenout = result[2] + + return self.attenout + + def grad(self, dy): + attr_code = f""" + op.jt_name = "flashattentionbackward"; + FlashAttentionAttr *attr = new FlashAttentionAttr(); + attr->scale = {self.scale}; + attr->keepProb = {self.prob}; + attr->preToken = {self.pretokens}; + attr->nextToken = {self.nexttokens}; + attr->headNum = {self.headnum}; + attr->inputLayout = "{self.layout}"; + attr->innerPrecise = {self.innerprecise}; + attr->sparseMode = {self.sparsemode}; + attr->psetype = {self.psetype}; + attr->prefix = {{ {", ".join(map(str, self.prefix))} }}; + attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }}; + attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }}; + attr->hasRealshift = {"true" if self.hasRealshift else "false"}; + attr->hasDropmask = {"true" if self.hasDropmask else "false"}; + attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"}; + attr->hasAttentmask = {"true" if self.hasAttenmask else "false"}; + op.op_attr.reset(attr); + """ + inputs = [ + self.q, self.k, self.v, dy, self.realshift, self.dropMask, + self.paddingMask, self.attenMask, self.maxout, self.sumout, + self.attenout + ] + + result = flashattention_cmd( + "FlashAttentionBackward", + inputs, + output_dtypes=[self.q.dtype, self.k.dtype, self.v.dtype], + output_shapes=[self.q.shape, self.k.shape, self.v.shape], + attr_code=attr_code) + return result diff --git a/python/jittor/extern/acl/aclops/flashattention_op_acl.cc b/python/jittor/extern/acl/aclops/flashattention_op_acl.cc new file mode 100644 index 00000000..43a71ab7 --- /dev/null +++ b/python/jittor/extern/acl/aclops/flashattention_op_acl.cc @@ -0,0 +1,88 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include 
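The attr_code fragments above are ordinary Python f-strings that get spliced into the generated C++ runner source: Python lists become brace initializers via ", ".join(map(str, ...)) and Python booleans are rendered as the literals "true"/"false". A tiny illustration of the text this produces (the variable names here are only for demonstration):

prefix = [0, 0]
has_dropmask = False
snippet = (
    f"attr->prefix = {{ {', '.join(map(str, prefix))} }};\n"
    f"attr->hasDropmask = {'true' if has_dropmask else 'false'};\n"
)
print(snippet)
# attr->prefix = { 0, 0 };
# attr->hasDropmask = false;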
"acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "flashattention_op_acl.h" + +namespace jittor +{ + FlashAttentionOpRunner::FlashAttentionOpRunner() : BaseOpRunner("FlashAttention") + { + } + + void FlashAttentionOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size()); + auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size()); + auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size()); + char *layout = const_cast(attr->inputLayout.data()); + ret = aclnnFlashAttentionScoreV2GetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], attr->hasRealshift ? inputTensors[3] : nullptr, attr->hasDropmask ? inputTensors[4] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[6] : nullptr, prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], nullptr, outputTensors[2], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnFlashAttentionScoreV2(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnFlashAttentionScoreV2 failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + return; + } + + FlashAttentionBackwardOpRunner::FlashAttentionBackwardOpRunner() : BaseOpRunner("FlashAttentionBackward") + { + } + + void FlashAttentionBackwardOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size()); + auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size()); + auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size()); + char *layout = const_cast(attr->inputLayout.data()); + ret = aclnnFlashAttentionScoreGradV2GetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], attr->hasRealshift ? inputTensors[4] : nullptr, attr->hasDropmask ? inputTensors[5] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[7] : nullptr, inputTensors[8], inputTensors[9], nullptr, inputTensors[10], prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], outputTensors[2], nullptr, &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnFlashAttentionScoreGradV2(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnFlashAttentionScoreGradV2 failed. 
ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/flashattention_op_acl.h b/python/jittor/extern/acl/aclops/flashattention_op_acl.h new file mode 100644 index 00000000..16c02caa --- /dev/null +++ b/python/jittor/extern/acl/aclops/flashattention_op_acl.h @@ -0,0 +1,27 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class FlashAttentionOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + FlashAttentionOpRunner(); + }; + + class FlashAttentionBackwardOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + FlashAttentionBackwardOpRunner(); + }; + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/flip_op.py b/python/jittor/extern/acl/aclops/flip_op.py new file mode 100644 index 00000000..f05c8be7 --- /dev/null +++ b/python/jittor/extern/acl/aclops/flip_op.py @@ -0,0 +1,85 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def flip_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + + +class FlipACL(jt.Function): + + def __init__(self): + super(FlipACL, self).__init__() + + def execute(self, input, dim): + if type(dim) is tuple: + dim = list(dim) + if type(dim) is not list: + dim = [dim] + attr_code = f""" + op.jt_name = "flip"; + ReduceAttr *attr = new ReduceAttr(); + attr->axes = {{{', '.join(map(str, (list(dim))))}}}; + attr->prod_dim = {len(dim)}; + op.op_attr.reset(attr); + """ + self.attr_code = attr_code + result = flip_cmd("Flip", [input], + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr_code=self.attr_code)[0] + return result + + def grad(self, grad_output): + grad_input = flip_cmd("Flip", [grad_output], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr_code=self.attr_code)[0] + return grad_input diff --git a/python/jittor/extern/acl/aclops/flip_op_acl.cc b/python/jittor/extern/acl/aclops/flip_op_acl.cc new file mode 100644 index 00000000..273168a4 --- /dev/null +++ b/python/jittor/extern/acl/aclops/flip_op_acl.cc @@ -0,0 +1,58 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include 
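FlipACL.grad above reuses the same attr_code for the backward pass: flip is a fixed permutation of elements and its own inverse along the chosen dims, so the gradient of flip(x, dims) is simply flip(grad_output, dims). A short NumPy check of that adjoint identity (illustrative, assuming NumPy is available):

import numpy as np

g = np.random.rand(2, 3)  # upstream gradient w.r.t. flip(x, 1)
v = np.random.rand(2, 3)  # arbitrary direction in input space
# <flip(g, 1), v> must equal <g, flip(v, 1)> because flipping is self-adjoint.
assert np.isclose((np.flip(g, 1) * v).sum(), (g * np.flip(v, 1)).sum())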
"ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "flip_op_acl.h" + +namespace jittor +{ + FlipOpRunner::FlipOpRunner() : BaseOpRunner("Flip") + { + } + + void FlipOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + auto dim = aclCreateIntArray(attr->axes.data(), attr->axes.size()); + ret = aclnnFlipGetWorkspaceSize(inputTensors[0], dim, outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnFlip(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnFlip failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/flip_op_acl.h b/python/jittor/extern/acl/aclops/flip_op_acl.h new file mode 100644 index 00000000..5b53700a --- /dev/null +++ b/python/jittor/extern/acl/aclops/flip_op_acl.h @@ -0,0 +1,16 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class FlipOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + FlipOpRunner(); + }; +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/floor_op.py b/python/jittor/extern/acl/aclops/floor_op.py new file mode 100644 index 00000000..35bed012 --- /dev/null +++ b/python/jittor/extern/acl/aclops/floor_op.py @@ -0,0 +1,70 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def floor_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + + +class FloorIntACL(jt.Function): + + def __init__(self): + super(FloorIntACL, self).__init__() + + def execute(self, input): + self.shape = input.shape + result = floor_cmd("Floor", [input], + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr_code="op.jt_name=\"floor\";")[0] + return result + + def grad(self, grad_output): + return 
jt.zeros(self.shape, dtype=grad_output.dtype) diff --git a/python/jittor/extern/acl/aclops/floor_op_acl.cc b/python/jittor/extern/acl/aclops/floor_op_acl.cc new file mode 100644 index 00000000..46118310 --- /dev/null +++ b/python/jittor/extern/acl/aclops/floor_op_acl.cc @@ -0,0 +1,56 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "floor_op_acl.h" + +namespace jittor +{ + FloorOpRunner::FloorOpRunner() : BaseOpRunner("Floor") + { + } + + void FloorOpRunner::executeOp(std::unordered_map::iterator &it) + { + ret = aclnnFloorGetWorkspaceSize(inputTensors[0], outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnFloor(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnFloor failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/floor_op_acl.h b/python/jittor/extern/acl/aclops/floor_op_acl.h new file mode 100644 index 00000000..3e228b16 --- /dev/null +++ b/python/jittor/extern/acl/aclops/floor_op_acl.h @@ -0,0 +1,16 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class FloorOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + FloorOpRunner(); + }; +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/gather_scatter_op.py b/python/jittor/extern/acl/aclops/gather_scatter_op.py new file mode 100644 index 00000000..748c5718 --- /dev/null +++ b/python/jittor/extern/acl/aclops/gather_scatter_op.py @@ -0,0 +1,126 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def gather_scatter_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + 
{output_code} + {attr_code} + op.run();""") + + +class GatherACL(jt.Function): + + def __init__(self): + super(GatherACL, self).__init__() + + def execute(self, input, dim, index): + self.dim = dim + self.index = index + attr_code = f""" + op.jt_name = "gather"; + GatherAttr *attr = new GatherAttr(); + attr->dim = {dim}; + op.op_attr.reset(attr); + """ + result = gather_scatter_cmd("Gather", [input, index], + output_dtypes=[input.dtype], + output_shapes=[index.shape], + attr_code=attr_code)[0] + return result + + def grad(self, grad_output): + tmp = jt.zeros(self.index.shape, dtype=grad_output.dtype) + attr_code = f""" + op.jt_name = "scatter"; + ScatterAttr *attr = new ScatterAttr(); + attr->axis = {self.dim}; + attr->reduction = {1}; + op.op_attr.reset(attr); + """ + grad_input = gather_scatter_cmd("Scatter", + [tmp, self.index, grad_output], + output_dtypes=[grad_output.dtype], + output_shapes=[tmp.shape], + attr_code=attr_code)[0] + return grad_input + + +class ScatterACL(jt.Function): + + def __init__(self): + super(ScatterACL, self).__init__() + + def execute(self, input, dim, index, src, reduce='void'): + self.dim = dim + self.index = index + self.reduce = reduce + attr_code = f""" + op.jt_name = "scatter"; + ScatterAttr *attr = new ScatterAttr(); + attr->axis = {dim}; + attr->reduction = {1 if reduce == 'add' else 2 if reduce == 'mul' else 0}; + op.op_attr.reset(attr); + """ + result = gather_scatter_cmd("Scatter", [input, self.index, src], + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr_code=attr_code)[0] + return result + + def grad(self, grad_output): + attr_code = f""" + op.jt_name = "gather"; + GatherAttr *attr = new GatherAttr(); + attr->dim = {self.dim}; + op.op_attr.reset(attr); + """ + grad_input = gather_scatter_cmd("Gather", [grad_output, self.index], + output_dtypes=[grad_output.dtype], + output_shapes=[self.index.shape], + attr_code=attr_code)[0] + return grad_output, None, None, grad_input diff --git a/python/jittor/extern/acl/aclops/gather_scatter_op_acl.cc b/python/jittor/extern/acl/aclops/gather_scatter_op_acl.cc new file mode 100644 index 00000000..871f5e83 --- /dev/null +++ b/python/jittor/extern/acl/aclops/gather_scatter_op_acl.cc @@ -0,0 +1,80 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "gather_scatter_op_acl.h" + +namespace jittor +{ + GatherOpRunner::GatherOpRunner() : BaseOpRunner("Gather") + { + } + + void GatherOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + ret = aclnnGatherGetWorkspaceSize(inputTensors[0], attr->dim, inputTensors[1], outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnGather(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnGather failed. 
ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + return; + } + + ScatterOpRunner::ScatterOpRunner() : BaseOpRunner("Scatter") + { + } + + void ScatterOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + ret = aclnnScatterGetWorkspaceSize(inputTensors[0], attr->axis, inputTensors[1], inputTensors[2], attr->reduction, outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnScatter(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnScatter failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/gather_scatter_op_acl.h b/python/jittor/extern/acl/aclops/gather_scatter_op_acl.h new file mode 100644 index 00000000..dd95814f --- /dev/null +++ b/python/jittor/extern/acl/aclops/gather_scatter_op_acl.h @@ -0,0 +1,26 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class GatherOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + GatherOpRunner(); + }; + + class ScatterOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + ScatterOpRunner(); + }; +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/getitem_op.py b/python/jittor/extern/acl/aclops/getitem_op.py new file mode 100644 index 00000000..91fc5d02 --- /dev/null +++ b/python/jittor/extern/acl/aclops/getitem_op.py @@ -0,0 +1,419 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def getitem_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + + +def getitem_forward(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None, + extra_data: dict = {}): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + 
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + // aclop + {name}OpRunner op; + {input_code} + op.add(out0, false); + {attr_code} + op.run();""", + data=extra_data) + + +def caculate_shape(tensors): + if isinstance(tensors, jt.Var): + # tensors = tensors[0] + return tensors.shape + elif isinstance(tensors, (int, float)): + return [] + elif isinstance(tensors, (list, tuple)): + # return [caculate_shape(tensor) for tensor in tensors] + sub_shape = caculate_shape(tensors[0]) + return [len(tensors)] + sub_shape + else: + assert False, f"not implemented for {type(tensors)}" + + +def can_broadcast_and_shape(shape1, shape2): + """ + 检查两个张量是否可以广播,并返回广播后的形状。 + + 参数: + - shape1: 第一个张量的形状(tuple 或 list) + - shape2: 第二个张量的形状(tuple 或 list) + + 返回: + - can_broadcast: 布尔值,表示是否可以广播 + - broadcast_shape: 如果可以广播,返回广播后的形状;否则返回 None + """ + # 将形状转换为元组,以防输入是列表 + shape1 = tuple(shape1) + shape2 = tuple(shape2) + + # 使两个形状的长度一致,通过在前面补1 + len1, len2 = len(shape1), len(shape2) + if len1 < len2: + shape1 = (1, ) * (len2 - len1) + shape1 + elif len2 < len1: + shape2 = (1, ) * (len1 - len2) + shape2 + + broadcast_shape = [] + + # 从最后一维开始检查每一维度 + for dim1, dim2 in zip(shape1, shape2): + if dim1 == dim2: + broadcast_shape.append(dim1) + elif dim1 == 1: + broadcast_shape.append(dim2) + elif dim2 == 1: + broadcast_shape.append(dim1) + else: + # 如果在某一维度上不兼容,则不能广播 + return False, None + + return True, tuple(broadcast_shape) + + +class GetItemACL(jt.Function): + + def __init__(self): + self.type_ = 'notype' + + def stride(self, x, dim): + stride = 1 + for i in range(dim + 1, len(x.shape)): + stride *= x.shape[i] + return stride + + def execute(self, x, slices, return_x=None): + if isinstance(slices, jt.Var) and slices.dtype == 'bool': + # assert False, "not support bool type now" + #TODO:优化 + assert x.shape == slices.shape, "shape not match" + output_len = slices.sum().item() + # output = jt.empty((output_len,),dtype=x.dtype) + x_len = x.numel() + output = jt.empty((x_len), dtype=x.dtype) + outputs = [output] + inputs = [x, slices] + # print(inputs,outputs) + # print(output.shape) + self.mask = slices + self.type_ = 'mask' + attr_code = f""" + op.jt_name = "maskedselect"; + """ + result = getitem_cmd("MaskedSelect", + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] + result = result[:output_len] + result.sync() + return result + self.x_shape = x.shape + if not isinstance(slices, tuple): + slices = (slices, ) + slices = list(slices) + for i, s in enumerate(slices): + if isinstance(s, int) and s < 0: + slices[i] = s + x.shape[i] + slices = tuple(slices) + slices_list = list(slices) + # if not isinstance(slices[0], slice): + #check slices contains slice type + contains_slice = False + for s in slices: + if not isinstance(s, jt.Var) and (isinstance(s, slice) + or s == Ellipsis): + contains_slice = True + break + if not contains_slice: + indices = [] + output_shape = [] + slices_len = len(slices) + boardcast_shape = caculate_shape(slices_list[0]) + for ii in range(1, len(slices)): + dd, boardcast_shape = can_broadcast_and_shape( + boardcast_shape, caculate_shape(slices_list[ii])) + assert dd is True, "can not broadcast" + output_shape = boardcast_shape + output_shape += x.shape[slices_len:] + if output_shape 
== []: + output_shape = [1] + for ii in slices: + indices.append(jt.Var(ii).int32()) + if isinstance(slices[0], + jt.Var) or isinstance(slices[0], int) or isinstance( + slices[0], list) or isinstance(slices[0], tuple): + self.indices = indices + inputs = [x] + indices + attr_code = f""" + op.jt_name = "index"; + """ + self.type_ = 'index' + result = getitem_cmd("Index", + inputs=inputs, + output_dtypes=[x.dtype], + output_shapes=[output_shape], + attr_code=attr_code)[0] + result.sync() + return result + assert contains_slice, "slice type error" + x_dim = len(x.shape) + slices = list(slices) + for s in slices: + if not isinstance(s, jt.Var) and s == Ellipsis: + slices = slices[:slices.index(s)] + [ + slice(None, None, None) + ] * (x_dim - len(slices) + 1) + slices[slices.index(s) + 1:] + break + slices = tuple(slices) + + if len(slices) < x_dim: + slices += (slice(None, None, None), ) * (x_dim - len(slices)) + inputs = [x] + sizes = [] + begins = [] + ends = [] + steps = [] + dims = [] + squeeze_dims = [] + + extra_data = {} + if len(slices): + extra_data["a"] = len(slices) + for dim, s in enumerate(slices): + if isinstance(s, int): + s = slice(s, s + 1, 1) + squeeze_dims.append(dim) + if isinstance(s, jt.Var): + assert False, "jt.Var not supported" + start, stop, step = s.indices(x.size(dim)) + size = (stop - start - 1) // step + 1 + # stride = self.stride(x, dim) * step + sizes.append(size) + extra_data[str(dim * 3)] = start + extra_data[str(dim * 3 + 1)] = stop + extra_data[str(dim * 3 + 2)] = step + + steps.append(step) + begins.append(start) + ends.append(stop) + dims.append(dim) + else: + extra_data["a"] = -1 + sizes = [1] + steps = [1] + self.type_ = 'slicev2' + # for backward + self.begins = begins + self.ends = ends + self.steps = steps + self.dims = dims + + self.slices = slices + attr_code = """ + op.jt_name = "slicev2"; + StrideAttr *attr = new StrideAttr(); + + int slice_dim = data["a"]; + + if(slice_dim == -1) { + attr->begins = {}; + attr->ends = {}; + attr->steps = {1}; + attr->axes = {}; + } else { + vector begins; + vector ends; + vector steps; + vector dims; + for(int dim = 0; dim < slice_dim; dim++) { + dims.push_back(dim); + begins.push_back(data[std::to_string(dim*3)]); + ends.push_back(data[std::to_string(dim*3+1)]); + steps.push_back(data[std::to_string(dim*3+2)]); + } + attr->begins = begins; + attr->ends = ends; + attr->steps = steps; + attr->axes = dims; + } + op.op_attr.reset(attr); + """ + result = getitem_forward("SliceV2", + inputs, + output_dtypes=[x.dtype], + output_shapes=[jt.empty(sizes).shape], + attr_code=attr_code, + extra_data=extra_data)[0] + self.squeeze_dims = squeeze_dims + for dim in squeeze_dims[::-1]: + result = jt.squeeze(result, dim) + result.sync() + return result + + def grad(self, grad_output): + if self.type_ == 'index': + indices = self.indices + inputs = [grad_output] + indices + attr_code = f""" + op.jt_name = "indexputimplaccumulate"; + """ + outputs = [jt.zeros(self.x_shape, dtype=grad_output.dtype)] + # breakpoint() + result = getitem_cmd("IndexPutImplAccumulate", + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] + result.sync() + return result, None + elif self.type_ == 'slicev2': + begins = self.begins + ends = self.ends + steps = self.steps + dims = self.dims + slices = self.slices + #注意前向的维数可能会被压缩,所以这里要还原 + for dim in self.squeeze_dims: + grad_output = jt.unsqueeze(grad_output, dim) + #适配华为奇怪的要求,最后一个维度的step必须是1 + expand_dim = False + if isinstance(slices[-1], slice): + if slices[-1].step is not None and 
slices[-1].step != 1: + slices = slices + (slice(None, None, None), ) + expand_dim = True + elif isinstance(slices[-1], int): + #注意最后一个维度是数字 + slices = list(slices) + slices[-1] = slice(slices[-1], slices[-1] + 1, 1) + slices = tuple(slices) + slices = slices + (slice(None, None, None), ) + expand_dim = True + else: + assert False, "not supported" + # x = x.unsqueeze(-1) + if expand_dim: + grad_output = grad_output.unsqueeze(-1) + self.x_shape = self.x_shape + (1, ) + sizes = [] + begins = [] + ends = [] + steps = [] + dims = [] + for dim, s in enumerate(slices): + if isinstance(s, int): + s = slice(s, s + 1, 1) + # squeeze_dims.append(dim) + if isinstance(s, jt.Var): + assert False, "jt.Var not supported" + start, stop, step = s.indices(self.x_shape[dim]) + size = (stop - start - 1) // step + 1 + # stride = self.stride(x, dim) * step + sizes.append(size) + steps.append(step) + begins.append(start) + ends.append(stop) + dims.append(dim) + if not sizes: + sizes = [1] + steps = [1] + attr_code = f""" + op.jt_name = "stridedsliceassignv2"; + StrideAttr *attr = new StrideAttr(); + attr->begins = {{ {", ".join(map(str, begins))} }}; + attr->ends = {{ {", ".join(map(str, ends))} }}; + attr->steps = {{ {", ".join(map(str, steps))} }}; + attr->axes = {{ {", ".join(map(str, dims))} }}; + op.op_attr.reset(attr); + """ + inputs = [grad_output] + outputs = [jt.zeros(self.x_shape, dtype=grad_output.dtype)] + result = getitem_cmd("StridedSliceAssignV2", + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] + result.sync() + if expand_dim: + result = result.squeeze(-1) + return result, None + elif self.type_ == 'mask': + return self.mask.float() + pass + else: + assert False, f"grad not implemented for {self.type_}" diff --git a/python/jittor/extern/acl/aclops/getitem_op_acl.cc b/python/jittor/extern/acl/aclops/getitem_op_acl.cc new file mode 100644 index 00000000..4c7c34d3 --- /dev/null +++ b/python/jittor/extern/acl/aclops/getitem_op_acl.cc @@ -0,0 +1,165 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "getitem_op_acl.h" + +namespace jittor +{ + MaskedSelectOpRunner::MaskedSelectOpRunner() : BaseOpRunner("MaskedSelect") + { + } + + void MaskedSelectOpRunner::executeOp(std::unordered_map::iterator &it) + { + ret = aclnnMaskedSelectGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnMaskedSelect(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnMaskedSelect failed. 
ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + return; + } + + IndexOpRunner::IndexOpRunner() : BaseOpRunner("Index") + { + } + + void IndexOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto input_num = in_.size(); + auto indexTensorList = aclCreateTensorList(&inputTensors[1], input_num - 1); + ret = aclnnIndexGetWorkspaceSize(inputTensors[0], indexTensorList, outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnIndex(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnIndex failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + return; + } + + SliceV2OpRunner::SliceV2OpRunner() : BaseOpRunner("SliceV2") + { + } + + void SliceV2OpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + auto begins = aclCreateIntArray(attr->begins.data(), attr->begins.size()); + auto ends = aclCreateIntArray(attr->ends.data(), attr->ends.size()); + auto steps = aclCreateIntArray(attr->steps.data(), attr->steps.size()); + auto axes = aclCreateIntArray(attr->axes.data(), attr->axes.size()); + ret = aclnnSliceV2GetWorkspaceSize(inputTensors[0], begins, ends, axes, steps, outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnSliceV2(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSliceV2 failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + return; + } + + IndexPutImplAccumulateOpRunner::IndexPutImplAccumulateOpRunner() : BaseOpRunner("IndexPutImplAccumulate") + { + } + + void IndexPutImplAccumulateOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto input_num = in_.size(); + std::vector indexTensorList = {}; + for (int i = 1; i < input_num; i++) + { + indexTensorList.push_back(inputTensors[i]); + } + auto indexTensorListInput = aclCreateTensorList(&indexTensorList[0], input_num - 1); + ret = aclnnIndexPutImplGetWorkspaceSize(outputTensors[0], indexTensorListInput, inputTensors[0], true, true, &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnIndexPutImpl(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnIndexPutImpl failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + return; + } + + StridedSliceAssignV2OpRunner::StridedSliceAssignV2OpRunner() : BaseOpRunner("StridedSliceAssignV2") + { + } + + void StridedSliceAssignV2OpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + auto begins = aclCreateIntArray(attr->begins.data(), attr->begins.size()); + auto ends = aclCreateIntArray(attr->ends.data(), attr->ends.size()); + auto steps = aclCreateIntArray(attr->steps.data(), attr->steps.size()); + auto axes = aclCreateIntArray(attr->axes.data(), attr->axes.size()); + ret = aclnnStridedSliceAssignV2GetWorkspaceSize(outputTensors[0], inputTensors[0], begins, ends, steps, axes, &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnStridedSliceAssignV2(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnStridedSliceAssignV2 failed. 
ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + return; + } +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/getitem_op_acl.h b/python/jittor/extern/acl/aclops/getitem_op_acl.h new file mode 100644 index 00000000..481ab15f --- /dev/null +++ b/python/jittor/extern/acl/aclops/getitem_op_acl.h @@ -0,0 +1,57 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class MaskedSelectOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + MaskedSelectOpRunner(); + }; + + class IndexOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + IndexOpRunner(); + }; + + class SliceV2OpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + SliceV2OpRunner(); + }; + + class IndexPutImplAccumulateOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + IndexPutImplAccumulateOpRunner(); + }; + + class StridedSliceAssignV2OpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + StridedSliceAssignV2OpRunner(); + }; + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/index_op.py b/python/jittor/extern/acl/aclops/index_op.py new file mode 100644 index 00000000..087f2a97 --- /dev/null +++ b/python/jittor/extern/acl/aclops/index_op.py @@ -0,0 +1,107 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def range_forward(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None, + extra_data: dict = {}): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + // aclop + {name}OpRunner op; + {input_code} + op.add(out0, false); + {attr_code} + op.run();""", + data=extra_data) + + +class IndexACL(jt.Function): + + def __init__(self): + super(IndexACL, self).__init__() + + def execute(self, inshape: list, dim=None, dtype="int32"): + # zeros a tensor, shape is inshape, dtype is dtype + dim_input = dim + if dim == None: + dim = [i for i in range(len(inshape))] + elif type(dim) == int: + dim = [dim] + results = [] + extra_data = {} + extra_data["dim_count"] = len(dim) + + for i, d in enumerate(dim): + max_len = inshape[d] + + extra_data[f"dim_{i}_start"] = 0 + extra_data[f"dim_{i}_end"] = max_len + extra_data[f"dim_{i}_step"] = 1 + + tmp = jt.zeros(max_len, dtype=dtype) + range_attr_code = f""" + op.jt_name = "range"; + RangeAttr 
*attr = new RangeAttr(); + attr->start = data["dim_{i}_start"]; + attr->end = data["dim_{i}_end"]; + attr->step = data["dim_{i}_step"]; + op.op_attr.reset(attr); + """ + result = range_forward("Range", [], + output_dtypes=[tmp.dtype], + output_shapes=[tmp.shape], + attr_code=range_attr_code, + extra_data=extra_data)[0] + broadcast_dims = list(range(len(inshape))) + broadcast_dims.remove(d) + result = jt.broadcast(result, shape=inshape, dims=broadcast_dims) + results.append(result) + + if len(results) != 1 or dim_input == None: + return tuple(results) + elif len(results) == 1 and dim_input != None: + return results[0] + else: + return results + + def grad(self, grad_output): + return grad_output diff --git a/python/jittor/extern/acl/aclops/index_op_acl.cc b/python/jittor/extern/acl/aclops/index_op_acl.cc new file mode 100644 index 00000000..33853b7e --- /dev/null +++ b/python/jittor/extern/acl/aclops/index_op_acl.cc @@ -0,0 +1,72 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "index_op_acl.h" + +namespace jittor +{ + RangeOpRunner::RangeOpRunner() : BaseOpRunner("Range") + { + } + + void RangeOpRunner::executeOp(std::unordered_map::iterator &it) + { + aclScalar *start = nullptr; + aclScalar *end = nullptr; + aclScalar *step = nullptr; + + auto attr = dynamic_cast(op_attr.get()); + int64_t startValue = attr->start; + int64_t endValue = attr->end; + int64_t stepValue = attr->step; + start = aclCreateScalar(&startValue, aclDataType::ACL_INT64); + end = aclCreateScalar(&endValue, aclDataType::ACL_INT64); + step = aclCreateScalar(&stepValue, aclDataType::ACL_INT64); + + ret = aclnnRangeGetWorkspaceSize(start, end, step, outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnRange(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnRange failed. 
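# --- Illustrative reference (not part of the patch) ---------------------------
# What IndexACL.execute computes: for each requested dim d it launches the ACL
# Range op to build arange(inshape[d]) and broadcasts it over the remaining
# axes, i.e. the same result as jt.index / numpy.indices. Pure-NumPy sketch:
import numpy as np

def index_reference(inshape, dim=None):
    dims = range(len(inshape)) if dim is None else [dim]
    outs = []
    for d in dims:
        r = np.arange(inshape[d])                                # the Range op
        view = [1] * len(inshape)
        view[d] = inshape[d]
        outs.append(np.broadcast_to(r.reshape(view), inshape))   # the broadcast step
    return outs[0] if isinstance(dim, int) else tuple(outs)

# index_reference((2, 3), 1) -> [[0, 1, 2], [0, 1, 2]]
# -----------------------------------------------------------------------------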
ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + aclDestroyScalar(start); + aclDestroyScalar(end); + aclDestroyScalar(step); + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/index_op_acl.h b/python/jittor/extern/acl/aclops/index_op_acl.h new file mode 100644 index 00000000..e69bf39c --- /dev/null +++ b/python/jittor/extern/acl/aclops/index_op_acl.h @@ -0,0 +1,17 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class RangeOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + RangeOpRunner(); + }; + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/matmul_op.py b/python/jittor/extern/acl/aclops/matmul_op.py new file mode 100644 index 00000000..fbb8df71 --- /dev/null +++ b/python/jittor/extern/acl/aclops/matmul_op.py @@ -0,0 +1,130 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def matmul_forward(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None, + extra_data: dict = {}): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + // aclop + MatMulOpRunner op; + {input_code} + op.add(out0, false); + {attr_code} + op.run();""", + data=extra_data) + + +class MatmulACL(jt.Function): + + def __init__(self, trans_x2=False): + super(MatmulACL, self).__init__() + self.trans_x2 = trans_x2 + + def execute(self, x1, x2): + self.input = [x1, x2] + result = matmul_forward( + "MatMul", [x1, x2], + output_dtypes=[x1.dtype], + output_shapes=[ + x1.shape[:-1] + + x2.shape[-2:-1] if self.trans_x2 else x1.shape[:-1] + + x2.shape[-1:] + ], + attr_code="op.jt_name=\"matmul_trans_1\";" + if self.trans_x2 else "op.jt_name=\"matmul\";")[0] + return result + + def grad(self, grad_output): + x1, x2 = self.input + if len(x1) != len(x2): + reshape_grad_x2 = True + else: + reshape_grad_x2 = False + grad_x1 = matmul_forward( + "MatMul", [grad_output, x2], + output_dtypes=[x1.dtype], + output_shapes=[ + grad_output.shape[:-1] + x2.shape[-2:-1] if not self.trans_x2 + else grad_output.shape[:-1] + x2.shape[-1:] + ], + attr_code="op.jt_name=\"matmul_trans_1\";" + if not self.trans_x2 else "op.jt_name=\"matmul\";")[0] + + if self.trans_x2: + if reshape_grad_x2: + output_shape = grad_output.shape[1:-2] + grad_output.shape[ + -1:] + x1.shape[-1:] + grad_x2 = matmul_forward( + "MatMul", [ + grad_output.reshape(-1, grad_output.shape[-1]), + x1.reshape(-1, x1.shape[-1]) + ], + output_dtypes=[x2.dtype], + output_shapes=[output_shape], + attr_code="op.jt_name=\"matmul_trans_0\";")[0] + else: + output_shape = grad_output.shape[:-2] + grad_output.shape[ + -1:] + 
x1.shape[-1:] + grad_x2 = matmul_forward( + "MatMul", [grad_output, x1], + output_dtypes=[x2.dtype], + output_shapes=[output_shape], + attr_code="op.jt_name=\"matmul_trans_0\";")[0] + else: + if reshape_grad_x2: + output_shape = x1.shape[1:-2] + x1.shape[ + -1:] + grad_output.shape[-1:] + grad_x2 = matmul_forward( + "MatMul", [ + x1.reshape(-1, x1.shape[-1]), + grad_output.reshape(-1, grad_output.shape[-1]) + ], + output_dtypes=[x2.dtype], + output_shapes=[output_shape], + attr_code="op.jt_name=\"matmul_trans_0\";")[0] + else: + output_shape = x1.shape[:-2] + x1.shape[ + -1:] + grad_output.shape[-1:] + grad_x2 = matmul_forward( + "MatMul", [x1, grad_output], + output_dtypes=[x2.dtype], + output_shapes=[output_shape], + attr_code="op.jt_name=\"matmul_trans_0\";")[0] + return grad_x1, grad_x2 diff --git a/python/jittor/extern/acl/aclops/matmul_op_acl.cc b/python/jittor/extern/acl/aclops/matmul_op_acl.cc new file mode 100644 index 00000000..af109cbb --- /dev/null +++ b/python/jittor/extern/acl/aclops/matmul_op_acl.cc @@ -0,0 +1,77 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "matmul_op_acl.h" + +namespace jittor +{ + MatMulOpRunner::MatMulOpRunner() : BaseOpRunner("MatMul") + { + } + void MatMulOpRunner::setupInputDesc() + { + auto input_num = in_.size(); + for (int input_idx = 0; input_idx < input_num; input_idx++) + { + std::vector shape; + for (int j = 0; j < in_[input_idx]->shape.size(); j++) + { + shape.push_back(in_[input_idx]->shape[j]); + } + inputShapes.push_back(shape); + } + for (int idx = 0; idx < input_num; idx++) + { + inputTensors.push_back(nullptr); + if ((jt_name == "matmul_trans_1" && idx == 1) || (jt_name == "matmul_trans_0" && idx == 0)) + { + auto ret = CreateFakeTransAclTensor(inputShapes[idx], in_[idx]->mem_ptr, in_[idx]->size, get_dtype(in_[idx]->dtype()), &inputTensors[idx], use_nchw); + CHECK_RET(ret == ACL_SUCCESS, return); + } + else + { + auto ret = CreateAclTensor(inputShapes[idx], in_[idx]->mem_ptr, in_[idx]->size, get_dtype(in_[idx]->dtype()), &inputTensors[idx], use_nchw); + CHECK_RET(ret == ACL_SUCCESS, return); + } + } + } + void MatMulOpRunner::executeOp(std::unordered_map::iterator &it) + { + + ret = aclnnMatmulGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnMatmulGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return); + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + ret = aclnnMatmul(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnMatmul failed. 
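# --- Illustrative reference (not part of the patch) ---------------------------
# Gradient identities that MatmulACL.grad realises with extra MatMul launches;
# "matmul_trans_1" / "matmul_trans_0" appear to transpose the second / first
# operand respectively:
#   y = x1 @ x2    ->  dx1 = dy @ x2.T,   dx2 = x1.T @ dy
#   y = x1 @ x2.T  ->  dx1 = dy @ x2,     dx2 = dy.T @ x1
import numpy as np

x1, x2 = np.random.rand(4, 5), np.random.rand(5, 3)
dy = np.random.rand(4, 3)       # upstream gradient, shape of x1 @ x2
dx1 = dy @ x2.T                 # MatMul(dy, x2) with the second operand transposed
dx2 = x1.T @ dy                 # MatMul(x1, dy) with the first operand transposed
assert dx1.shape == x1.shape and dx2.shape == x2.shape
# -----------------------------------------------------------------------------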
ERROR: %d\n", name.c_str(), ret); return); + syncRun(); + } +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/matmul_op_acl.h b/python/jittor/extern/acl/aclops/matmul_op_acl.h new file mode 100644 index 00000000..ab82edc0 --- /dev/null +++ b/python/jittor/extern/acl/aclops/matmul_op_acl.h @@ -0,0 +1,17 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class MatMulOpRunner : public BaseOpRunner + { + + protected: + void setupInputDesc() override; + void executeOp(std::unordered_map::iterator &it) override; + + public: + MatMulOpRunner(); + }; +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/nantonum_op.py b/python/jittor/extern/acl/aclops/nantonum_op.py new file mode 100644 index 00000000..2a36c999 --- /dev/null +++ b/python/jittor/extern/acl/aclops/nantonum_op.py @@ -0,0 +1,75 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def nantonum_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + + +class NanToNumACL(jt.Function): + + def __init__(self): + super(NanToNumACL, self).__init__() + + def execute(self, input, nan_or_inf): + attr_code = f""" + op.jt_name = "NanToNum"; + NanToNumAttr *attr = new NanToNumAttr(); + attr->nan = {nan_or_inf}; + attr->posinf = {-nan_or_inf}; + attr->neginf = {-nan_or_inf}; + op.op_attr.reset(attr); + """ + self.attr_code = attr_code + result = nantonum_cmd("NanToNum", [input], + output_dtypes=[input[0].dtype], + output_shapes=[input.shape], + attr_code=self.attr_code)[0] + return result diff --git a/python/jittor/extern/acl/aclops/nantonum_op_acl.cc b/python/jittor/extern/acl/aclops/nantonum_op_acl.cc new file mode 100644 index 00000000..e461b292 --- /dev/null +++ b/python/jittor/extern/acl/aclops/nantonum_op_acl.cc @@ -0,0 +1,58 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" 
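# --- Illustrative reference (not part of the patch) ---------------------------
# Replacement semantics NanToNumACL delegates to aclnnNanToNum: NaN, +inf and
# -inf are mapped to the three constants carried in NanToNumAttr. The same
# behaviour expressed with NumPy:
import numpy as np

x = np.array([1.0, np.nan, np.inf, -np.inf])
y = np.nan_to_num(x, nan=0.0, posinf=1e4, neginf=-1e4)
# -> [1.0, 0.0, 10000.0, -10000.0]
# -----------------------------------------------------------------------------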
+#include "aclnn/aclnn.h" +#include "nantonum_op_acl.h" + +namespace jittor +{ + NanToNumOpRunner::NanToNumOpRunner() : BaseOpRunner("NanToNum") + { + } + + void NanToNumOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + ret = aclnnNanToNumGetWorkspaceSize(inputTensors[0], attr->nan, attr->posinf, attr->neginf, outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnNanToNum(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnNanToNum failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/nantonum_op_acl.h b/python/jittor/extern/acl/aclops/nantonum_op_acl.h new file mode 100644 index 00000000..924c0080 --- /dev/null +++ b/python/jittor/extern/acl/aclops/nantonum_op_acl.h @@ -0,0 +1,17 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class NanToNumOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + NanToNumOpRunner(); + }; + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/norms_op.py b/python/jittor/extern/acl/aclops/norms_op.py new file mode 100644 index 00000000..c724aa8b --- /dev/null +++ b/python/jittor/extern/acl/aclops/norms_op.py @@ -0,0 +1,184 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + +def norms_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + +class BatchNormACL(jt.Function): + + def __init__(self, + num_features, + eps=1e-05, + momentum=0.1, + affine=True, + is_train=True, + sync=True): + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.affine = affine + self.is_train = is_train + self.sync = sync + self.weight = jt.init.constant( + (num_features, ), "float32", 1.0) if affine else 1.0 + self.bias = jt.init.constant( + (num_features, ), "float32", 0.0) if affine else 0.0 + self.running_mean = jt.init.constant((num_features, ), "float32", + 0.0).stop_grad() + self.running_var = jt.init.constant((num_features, ), "float32", + 1.0).stop_grad() + + def execute(self, x): + # assert self.num_features == x.shape[-1] + self.input = x.float32() + inputs = [ + self.input, self.weight, 
self.bias, self.running_mean, + self.running_var + ] + outputs = [ + jt.empty(x.shape), + jt.empty(self.num_features), + jt.empty(self.num_features) + ] + attr_code = f""" + op.jt_name = "batchnorm"; + BatchNormAttr *attr = new BatchNormAttr(); + attr->is_train = {"true" if self.is_train else "false"}; + attr->momentum = {self.momentum}; + attr->eps = {self.eps}; + op.op_attr.reset(attr); + """ + result = norms_cmd("BatchNorm", + inputs=inputs, + outputs=outputs, + attr_code=attr_code) + self.output = result[0] + self.saveMean = result[1] + self.saveInvstd = result[2] + return self.output + + def grad(self, grad_output): + attr_code = f""" + op.jt_name = "batchnorm"; + BatchNormAttr *attr = new BatchNormAttr(); + attr->is_train = {"true" if self.is_train else "false"}; + attr->momentum = {self.momentum}; + attr->eps = {self.eps}; + op.op_attr.reset(attr); + """ + inputs = [ + grad_output, self.input, self.weight, self.running_mean, + self.running_var, self.saveMean, self.saveInvstd + ] + outputs = [ + jt.empty(self.input.shape), + jt.empty(self.num_features), + jt.empty(self.num_features) + ] + grad_input = norms_cmd("BatchNormBackward", + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] + return grad_input + + +class LayerNormACL(jt.Function): + + def __init__(self, + normalized_shape, + eps: float = 1e-5, + elementwise_affine: bool = True): + if isinstance(normalized_shape, int): + normalized_shape = (normalized_shape, ) + self.normalized_shape = tuple(normalized_shape) + self.eps = eps + self.elementwise_affine = elementwise_affine + self.weight = jt.init.constant(normalized_shape, "float32", + 1.0) if elementwise_affine else 1.0 + self.bias = jt.init.constant(normalized_shape, "float32", + 0.0) if elementwise_affine else 0.0 + + def execute(self, x): + self.input = x.float32() + inputs = [self.input, self.weight, self.bias] + outputs = [jt.empty(x.shape), jt.empty(x.shape), jt.empty(x.shape)] + attr_code = f""" + op.jt_name = "layernorm"; + LayerNormAttr *attr = new LayerNormAttr(); + attr->eps = {self.eps}; + attr->normalizedShape = {{{', '.join(map(str, (list(self.normalized_shape))))}}}; + attr->size = {x.shape[-1]}; + op.op_attr.reset(attr); + """ + result = norms_cmd("LayerNorm", + inputs=inputs, + outputs=outputs, + attr_code=attr_code) + self.output = result[0] + self.meanout = result[1] + self.rstdout = result[2] + return self.output + + def grad(self, grad_output): + attr_code = f""" + op.jt_name = "batchnorm"; + BatchNormAttr *attr = new BatchNormAttr(); + attr->is_train = {"true" if self.is_train else "false"}; + attr->momentum = {self.momentum}; + attr->eps = {self.eps}; + op.op_attr.reset(attr); + """ + inputs = [grad_output, self.input, self.weight, self.running_mean, self.running_var, self.saveMean, self.saveInvstd] + outputs = [jt.empty(self.input.shape), jt.empty(self.num_features), jt.empty(self.num_features)] + grad_input = norms_cmd("SoftmaxBackward", + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] + return grad_input diff --git a/python/jittor/extern/acl/aclops/norms_op_acl.cc b/python/jittor/extern/acl/aclops/norms_op_acl.cc new file mode 100644 index 00000000..9e450b62 --- /dev/null +++ b/python/jittor/extern/acl/aclops/norms_op_acl.cc @@ -0,0 +1,111 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include 
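# --- Illustrative reference (not part of the patch) ---------------------------
# Training-mode computation that BatchNormACL hands to aclnnBatchNorm; the
# sketch assumes an NCHW input with per-channel statistics over N, H, W:
import numpy as np

def batch_norm_ref(x, weight, bias, eps=1e-5):
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)          # saveMean / saveInvstd feed the backward pass
    x_hat = (x - mean) / np.sqrt(var + eps)
    return x_hat * weight.reshape(1, -1, 1, 1) + bias.reshape(1, -1, 1, 1)

x = np.random.rand(2, 3, 4, 4).astype(np.float32)
y = batch_norm_ref(x, np.ones(3, np.float32), np.zeros(3, np.float32))
assert y.shape == x.shape
# -----------------------------------------------------------------------------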
"ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "norms_op_acl.h" + +namespace jittor +{ + BatchNormOpRunner::BatchNormOpRunner() : BaseOpRunner("BatchNorm") + { + } + + void BatchNormOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + ret = aclnnBatchNormGetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], inputTensors[4], attr->is_train, attr->momentum, attr->eps, outputTensors[0], outputTensors[1], outputTensors[2], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnBatchNorm(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnBatchNorm failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + return; + } + + BatchNormBackwardOpRunner::BatchNormBackwardOpRunner() : BaseOpRunner("BatchNormBackward") + { + } + + void BatchNormBackwardOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + bool outputMask[3] = {true, true, true}; + aclBoolArray *outMask = aclCreateBoolArray(outputMask, 3); + ret = aclnnBatchNormBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], inputTensors[4], inputTensors[5], inputTensors[6], attr->is_train, attr->eps, outMask, outputTensors[0], outputTensors[1], outputTensors[2], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnBatchNormBackward(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnBatchNormBackward failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + return; + } + + LayerNormOpRunner::LayerNormOpRunner() : BaseOpRunner("LayerNorm") + { + } + + void LayerNormOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + aclIntArray *normalizedShape = nullptr; + normalizedShape = aclCreateIntArray(attr->normalizedShape.data(), attr->size); + ret = aclnnLayerNormGetWorkspaceSize(inputTensors[0], normalizedShape, inputTensors[1], inputTensors[2], attr->eps, outputTensors[0], outputTensors[1], outputTensors[2], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnLayerNorm(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnLayerNorm failed. 
ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + aclDestroyIntArray(normalizedShape); + + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/norms_op_acl.h b/python/jittor/extern/acl/aclops/norms_op_acl.h new file mode 100644 index 00000000..2ad7c433 --- /dev/null +++ b/python/jittor/extern/acl/aclops/norms_op_acl.h @@ -0,0 +1,34 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class BatchNormOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + public: + BatchNormOpRunner(); + }; + + class BatchNormBackwardOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + public: + BatchNormBackwardOpRunner(); + }; + + class LayerNormOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + public: + LayerNormOpRunner(); + }; + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/pool_op.py b/python/jittor/extern/acl/aclops/pool_op.py new file mode 100644 index 00000000..583420a6 --- /dev/null +++ b/python/jittor/extern/acl/aclops/pool_op.py @@ -0,0 +1,176 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def pool_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + + +class PoolACL(jt.Function): + + def __init__(self, + kernel_size, + stride=None, + padding=0, + dilation=None, + return_indices=None, + ceil_mode=False, + count_include_pad=True, + op='maximum'): + self.kernel_size = kernel_size if isinstance( + kernel_size, tuple) else (kernel_size, kernel_size) + stride = stride if stride else kernel_size + self.stride = stride if isinstance(stride, tuple) else (stride, stride) + self.padding = padding if isinstance(padding, tuple) else (padding, + padding) + dilation = dilation if dilation else 1 + assert dilation == 1 + self.dilation = dilation if isinstance(dilation, tuple) else (dilation, + dilation) + for item in self.kernel_size: + if item <= 0: + raise RuntimeError( + f"kernel_size must be greater than zero, but got {item}") + for item in self.stride: + if item <= 0: + raise RuntimeError( + f"stride must be greater than zero, but got {item}") + for item in self.padding: + if item < 0: + raise RuntimeError( + f"padding must be non-negative, but got {item}") + self.op = op + self.return_indices = 
return_indices + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad + + def execute(self, input): + self.input = input + attr_code = f""" + op.jt_name = "{"avgpool" if self.op == 'mean' else "maxpool"}"; + PoolAttr *attr = new PoolAttr(); + attr->kernel_size = {{ {self.kernel_size[0]}, {self.kernel_size[1]} }}; + attr->poolStrides = {{ {self.stride[0]}, {self.stride[1]} }}; + attr->poolPads = {{ {self.padding[0]}, {self.padding[1]} }}; + attr->poolDilations = {{ {self.dilation[0]}, {self.dilation[1]} }}; + attr->poolCeil = {"true" if self.ceil_mode else "false"}; + attr->countIncludePad = {"true" if self.count_include_pad else "false"}; + op.op_attr.reset(attr); + """ + input_height, input_width = input.shape[-2:] + kernel_height, kernel_width = self.kernel_size[-2:] + + output_height = (input_height + 2 * self.padding[0] - + (kernel_height - 1) - 1) // self.stride[0] + 1 + output_width = (input_width + 2 * self.padding[1] - + (kernel_width - 1) - 1) // self.stride[1] + 1 + + output_shape = (input.shape[0], input.shape[1], output_height, + output_width) + + inputs = [input] + + if self.op == 'maximum': + result = pool_cmd( + "Maxpool", + inputs, + output_dtypes=[input.dtype, 'int32'], + output_shapes=[output_shape, output_shape], + attr_code=attr_code, + ) + elif self.op == 'mean': + result = pool_cmd( + "Avgpool", + inputs, + output_dtypes=[input.dtype], + output_shapes=[output_shape], + attr_code=attr_code, + ) + else: + raise ValueError('no this type pool') + + if self.op == 'maximum': + self.index = result[1] + + if self.return_indices: + return result[0], result[1] + else: + return result[0] + + def grad(self, grad_output): + input = self.input + attr_code = f""" + op.jt_name = "{"avgpoolbackward" if self.op == 'mean' else "maxpoolbackward"}"; + PoolAttr *attr = new PoolAttr(); + attr->kernel_size = {{ {self.kernel_size[0]}, {self.kernel_size[1]} }}; + attr->poolStrides = {{ {self.stride[0]}, {self.stride[1]} }}; + attr->poolPads = {{ {self.padding[0]}, {self.padding[1]} }}; + attr->poolDilations = {{ {self.dilation[0]}, {self.dilation[1]} }}; + attr->poolCeil = {"true" if self.ceil_mode else "false"}; + attr->countIncludePad = {"true" if self.count_include_pad else "false"}; + op.op_attr.reset(attr); + """ + output_shapes = [input.shape] + output_dtypes = [input.dtype] + if self.op == 'maximum': + result = pool_cmd("MaxpoolBackward", + inputs=[grad_output, input, self.index], + output_dtypes=output_dtypes, + output_shapes=output_shapes, + attr_code=attr_code)[0] + elif self.op == 'mean': + result = pool_cmd("AvgpoolBackward", + inputs=[grad_output, input], + output_dtypes=output_dtypes, + output_shapes=output_shapes, + attr_code=attr_code)[0] + else: + raise ValueError('no this type pool') + return result diff --git a/python/jittor/extern/acl/aclops/pool_op_acl.cc b/python/jittor/extern/acl/aclops/pool_op_acl.cc new file mode 100644 index 00000000..8781b4ee --- /dev/null +++ b/python/jittor/extern/acl/aclops/pool_op_acl.cc @@ -0,0 +1,187 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include 
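# --- Illustrative reference (not part of the patch) ---------------------------
# Output-size rule PoolACL.execute applies before calling
# aclnnMaxPool2dWithIndices / aclnnAvgPool2d (floor mode; dilation is asserted
# to be 1 in __init__):
def pool_out_size(in_size, kernel, stride, pad):
    return (in_size + 2 * pad - (kernel - 1) - 1) // stride + 1

# e.g. a 224x224 input with a 3x3 kernel, stride 2, padding 1 -> 112x112
assert pool_out_size(224, 3, 2, 1) == 112
# -----------------------------------------------------------------------------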
"op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "pool_op_acl.h" + +namespace jittor +{ + MaxpoolOpRunner::MaxpoolOpRunner() : BaseOpRunner("Maxpool") + { + use_nchw = true; + } + + void MaxpoolOpRunner::executeOp(std::unordered_map::iterator &it) + { + + aclIntArray *strides = nullptr; + aclIntArray *pads = nullptr; + aclIntArray *dilations = nullptr; + aclIntArray *kernel_size = nullptr; + + auto attr = dynamic_cast(op_attr.get()); + kernel_size = aclCreateIntArray(attr->kernel_size.data(), 2); + strides = aclCreateIntArray(attr->poolStrides.data(), 2); + pads = aclCreateIntArray(attr->poolPads.data(), 2); + dilations = aclCreateIntArray(attr->poolDilations.data(), 2); + ret = aclnnMaxPool2dWithIndicesGetWorkspaceSize(inputTensors[0], kernel_size, strides, pads, dilations, attr->poolCeil, outputTensors[0], outputTensors[1], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnMaxPool2dWithIndices(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnMaxPool2dWithIndices failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + aclDestroyIntArray(strides); + aclDestroyIntArray(pads); + aclDestroyIntArray(dilations); + aclDestroyIntArray(kernel_size); + + return; + } + + AvgpoolOpRunner::AvgpoolOpRunner() : BaseOpRunner("Avgpool") + { + use_nchw = true; + } + + void AvgpoolOpRunner::executeOp(std::unordered_map::iterator &it) + { + + aclIntArray *strides = nullptr; + aclIntArray *pads = nullptr; + aclIntArray *kernel_size = nullptr; + + auto attr = dynamic_cast(op_attr.get()); + kernel_size = aclCreateIntArray(attr->kernel_size.data(), 2); + strides = aclCreateIntArray(attr->poolStrides.data(), 2); + pads = aclCreateIntArray(attr->poolPads.data(), 2); + ret = aclnnAvgPool2dGetWorkspaceSize(inputTensors[0], kernel_size, strides, pads, attr->poolCeil, attr->countIncludePad, attr->divisorOverride, attr->divisorOverride, outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnAvgPool2d(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnAvgPool2d failed. 
ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + aclDestroyIntArray(strides); + aclDestroyIntArray(pads); + aclDestroyIntArray(kernel_size); + + return; + } + + MaxpoolBackwardOpRunner::MaxpoolBackwardOpRunner() : BaseOpRunner("MaxpoolBackward") + { + use_nchw = true; + } + + void MaxpoolBackwardOpRunner::executeOp(std::unordered_map::iterator &it) + { + + aclIntArray *strides = nullptr; + aclIntArray *pads = nullptr; + aclIntArray *dilations = nullptr; + aclIntArray *kernel_size = nullptr; + + auto attr = dynamic_cast(op_attr.get()); + kernel_size = aclCreateIntArray(attr->kernel_size.data(), 2); + strides = aclCreateIntArray(attr->poolStrides.data(), 2); + pads = aclCreateIntArray(attr->poolPads.data(), 2); + dilations = aclCreateIntArray(attr->poolDilations.data(), 2); + ret = aclnnMaxPool2dWithIndicesBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], kernel_size, strides, pads, dilations, attr->poolCeil, outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnMaxPool2dWithIndicesBackward(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnMaxPool2dWithIndicesBackward failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + aclDestroyIntArray(strides); + aclDestroyIntArray(pads); + aclDestroyIntArray(dilations); + aclDestroyIntArray(kernel_size); + + return; + } + + AvgpoolBackwardOpRunner::AvgpoolBackwardOpRunner() : BaseOpRunner("AvgpoolBackward") + { + use_nchw = true; + } + + void AvgpoolBackwardOpRunner::executeOp(std::unordered_map::iterator &it) + { + aclIntArray *strides = nullptr; + aclIntArray *pads = nullptr; + aclIntArray *kernel_size = nullptr; + + auto attr = dynamic_cast(op_attr.get()); + kernel_size = aclCreateIntArray(attr->kernel_size.data(), 2); + strides = aclCreateIntArray(attr->poolStrides.data(), 2); + pads = aclCreateIntArray(attr->poolPads.data(), 2); + ret = aclnnAvgPool2dBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], kernel_size, strides, pads, attr->countIncludePad, attr->divisorOverride, attr->divisorOverride, attr->poolCeil, outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnAvgPool2dBackward(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnAvgPool2dBackward failed. 
ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + aclDestroyIntArray(strides); + aclDestroyIntArray(pads); + aclDestroyIntArray(kernel_size); + + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/pool_op_acl.h b/python/jittor/extern/acl/aclops/pool_op_acl.h new file mode 100644 index 00000000..5116314a --- /dev/null +++ b/python/jittor/extern/acl/aclops/pool_op_acl.h @@ -0,0 +1,46 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class MaxpoolOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + MaxpoolOpRunner(); + }; + + class AvgpoolOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + AvgpoolOpRunner(); + }; + + class MaxpoolBackwardOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + MaxpoolBackwardOpRunner(); + }; + + class AvgpoolBackwardOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + AvgpoolBackwardOpRunner(); + }; +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/random_op_acl.cc b/python/jittor/extern/acl/aclops/random_op_acl.cc new file mode 100644 index 00000000..2fb18eba --- /dev/null +++ b/python/jittor/extern/acl/aclops/random_op_acl.cc @@ -0,0 +1,82 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "random_op_acl.h" + +namespace jittor +{ + RandomOpRunner::RandomOpRunner() : BaseOpRunner("RandomUniform") + { + name = "RandomUniform"; + } + + RandomOpRunner::RandomOpRunner(const string &_name) : BaseOpRunner(_name) + { + name = _name; + } + + void RandomOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + if (name == "RandomUniform") + { + ret = aclnnInplaceUniformGetWorkspaceSize(outputTensors[0], 0.0, 1.0, attr->seed, attr->offset, &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnInplaceUniform(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnInplaceUniform failed. ERROR: %d\n", name.c_str(), ret); return); + } + else if (name == "RandomNormal") + { + ret = aclnnInplaceNormalGetWorkspaceSize(outputTensors[0], 0.0, 1.0, attr->seed, attr->offset, &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnInplaceNormal(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnInplaceNormal failed. 
ERROR: %d\n", name.c_str(), ret); return); + } + else + { + LOGf << "Not supported random type : " << name; + } + syncRun(); + return; + } +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/random_op_acl.h b/python/jittor/extern/acl/aclops/random_op_acl.h new file mode 100644 index 00000000..af60dc37 --- /dev/null +++ b/python/jittor/extern/acl/aclops/random_op_acl.h @@ -0,0 +1,18 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class RandomOpRunner : public BaseOpRunner + { + + protected: + string name; // special to random op + void executeOp(std::unordered_map::iterator &it) override; + + public: + RandomOpRunner(); + RandomOpRunner(const string &name); + }; +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/reduce_op_acl.cc b/python/jittor/extern/acl/aclops/reduce_op_acl.cc new file mode 100644 index 00000000..e0c0b069 --- /dev/null +++ b/python/jittor/extern/acl/aclops/reduce_op_acl.cc @@ -0,0 +1,127 @@ +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "reduce_op_acl.h" + +namespace jittor +{ + ReduceOpRunner::ReduceOpRunner() : BaseOpRunner("reduce") + { + use_nchw = false; + } + + void ReduceOpRunner::setupOutputDesc() + { + auto output_num = out_.size(); + + for (int output_idx = 0; output_idx < output_num; output_idx++) + { + std::vector shape; + for (int j = 0; j < out_[output_idx]->shape.size(); j++) + { + shape.push_back(out_[output_idx]->shape[j]); + } + outputShapes.push_back(shape); + } + + attr = dynamic_cast(op_attr.get()); + dim = aclCreateIntArray(attr->axes.data(), attr->axes.size()); + keepdims = attr->keepdims; + + if (op_idx < 13) + { + if (attr->axes.size() == in_[0]->shape.size()) + outputShapes[0] = {}; + } + + for (int idx = 0; idx < output_num; idx++) + { + outputTensors.push_back(nullptr); + auto ret = CreateAclTensor(outputShapes[idx], out_[idx]->mem_ptr, out_[idx]->size, get_dtype(out_[idx]->dtype()), &outputTensors[idx], use_nchw); + CHECK_RET(ret == ACL_SUCCESS, return); + } + } + + void ReduceOpRunner::executeOp(std::unordered_map::iterator &it) + { + switch (op_idx) + { + case 9: + { + ret = aclnnReduceSumGetWorkspaceSize(inputTensors[0], dim, keepdims, get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnReduceSumGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return); + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + ret = aclnnReduceSum(workspaceAddr, workspaceSize, executor, aclstream); + break; + } + case 10: + { + ret = aclnnMeanGetWorkspaceSize(inputTensors[0], dim, keepdims, get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnMeanGetWorkspaceSize failed. 
ERROR: %d\n", name.c_str(), ret); return); + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + ret = aclnnMean(workspaceAddr, workspaceSize, executor, aclstream); + break; + } + case 11: + { + ret = aclnnAmaxGetWorkspaceSize(inputTensors[0], dim, keepdims, outputTensors[0], &workspaceSize, &executor); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnAmaxGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return); + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + ret = aclnnAmax(workspaceAddr, workspaceSize, executor, aclstream); + break; + } + case 12: + { + ret = aclnnAminGetWorkspaceSize(inputTensors[0], dim, keepdims, outputTensors[0], &workspaceSize, &executor); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnAminGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return); + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + ret = aclnnAmin(workspaceAddr, workspaceSize, executor, aclstream); + break; + } + default: + { + LOGir << "no such reduce!!"; + exit(-1); + } + } + syncRun(); + return; + } +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/reduce_op_acl.h b/python/jittor/extern/acl/aclops/reduce_op_acl.h new file mode 100644 index 00000000..f4953652 --- /dev/null +++ b/python/jittor/extern/acl/aclops/reduce_op_acl.h @@ -0,0 +1,21 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + struct ReduceOpRunner : public BaseOpRunner + { + int op_idx; // Specific to reduce operations + + ReduceOpRunner(); + + protected: + ReduceAttr *attr; + aclIntArray *dim; + bool keepdims; + + void setupOutputDesc() override; + void executeOp(std::unordered_map::iterator &it) override; + }; +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/relu_op.py b/python/jittor/extern/acl/aclops/relu_op.py new file mode 100644 index 00000000..def998e2 --- /dev/null +++ b/python/jittor/extern/acl/aclops/relu_op.py @@ -0,0 +1,115 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def relu_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + + +class ReLUACL(jt.Function): + + def __init__(self): + super(ReLUACL, self).__init__() + + def execute(self, x): + x = x.float32() + self.input = x + result = relu_cmd("Unary", [x], + output_dtypes=[x.dtype], + output_shapes=[x.shape], + attr_code="op.name=\"ReLU\";")[0] + return result + + def 
grad(self, grad_output): + mask = relu_cmd("Binary", + [self.input, jt.zeros(self.input.shape)], + output_dtypes=[self.input.dtype], + output_shapes=[self.input.shape], + attr_code="op.name=\"Greater\";")[0] + grad_input = relu_cmd("Binary", [grad_output, mask], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr_code="op.name=\"Mul\";")[0] + return grad_input + +class LeakyReLUACL(jt.Function): + + def __init__(self): + super(LeakyReLUACL, self).__init__() + + def execute(self, x, negative_slope=0.01): + x = x.float32() + self.input = x + attr_code = f""" + op.jt_name = "leakyrelu"; + LeakyReluAttr *attr = new LeakyReluAttr(); + attr->negativeSlope = {negative_slope}; + op.op_attr.reset(attr); + """ + result = relu_cmd("LeakyReLU", [x], + output_dtypes=[x.dtype], + output_shapes=[x.shape], + attr_code=attr_code)[0] + self.negative_slope = negative_slope + return result + + def grad(self, grad_output): + attr_code = f""" + op.jt_name = "leakyrelubackward"; + LeakyReluAttr *attr = new LeakyReluAttr(); + attr->negativeSlope = {self.negative_slope}; + attr->selfIsResult = false; + op.op_attr.reset(attr); + """ + grad_input = relu_cmd("LeakyReLUBackward", [grad_output, self.input], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr_code=attr_code)[0] + return grad_input diff --git a/python/jittor/extern/acl/aclops/relu_op_acl.cc b/python/jittor/extern/acl/aclops/relu_op_acl.cc new file mode 100644 index 00000000..9abeb711 --- /dev/null +++ b/python/jittor/extern/acl/aclops/relu_op_acl.cc @@ -0,0 +1,90 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "relu_op_acl.h" + +namespace jittor +{ + LeakyReLUOpRunner::LeakyReLUOpRunner() : BaseOpRunner("LeakyReLU") + { + } + + void LeakyReLUOpRunner::executeOp(std::unordered_map::iterator &it) + { + aclScalar *negativeSlope = nullptr; + + auto attr = dynamic_cast(op_attr.get()); + negativeSlope = aclCreateScalar(&attr->negativeSlope, aclDataType::ACL_FLOAT); + ret = aclnnLeakyReluGetWorkspaceSize(inputTensors[0], negativeSlope, outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnLeakyRelu(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnLeakyRelu failed. 
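# --- Illustrative reference (not part of the patch) ---------------------------
# ReLUACL.grad composes two ACL launches, Greater (input > 0 mask) and Mul,
# which is equivalent to:
import numpy as np

x = np.array([-1.0, 0.5, 2.0])
dy = np.ones_like(x)
dx = dy * (x > 0)       # mask = Greater(x, zeros); dx = Mul(dy, mask)
# -> [0.0, 1.0, 1.0]
# -----------------------------------------------------------------------------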
ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + aclDestroyScalar(negativeSlope); + return; + } + + LeakyReLUBackwardOpRunner::LeakyReLUBackwardOpRunner() : BaseOpRunner("LeakyReLUBackward") + { + } + + void LeakyReLUBackwardOpRunner::executeOp(std::unordered_map::iterator &it) + { + aclScalar *negativeSlope = nullptr; + + auto attr = dynamic_cast(op_attr.get()); + negativeSlope = aclCreateScalar(&attr->negativeSlope, aclDataType::ACL_FLOAT); + ret = aclnnLeakyReluBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], negativeSlope, attr->selfIsResult, outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnLeakyReluBackward(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnLeakyReluBackward failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + aclDestroyScalar(negativeSlope); + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/relu_op_acl.h b/python/jittor/extern/acl/aclops/relu_op_acl.h new file mode 100644 index 00000000..c436dd1e --- /dev/null +++ b/python/jittor/extern/acl/aclops/relu_op_acl.h @@ -0,0 +1,27 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class LeakyReLUOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + LeakyReLUOpRunner(); + }; + + class LeakyReLUBackwardOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + LeakyReLUBackwardOpRunner(); + }; + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/rope_op.py b/python/jittor/extern/acl/aclops/rope_op.py new file mode 100644 index 00000000..71269cf1 --- /dev/null +++ b/python/jittor/extern/acl/aclops/rope_op.py @@ -0,0 +1,84 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def rope_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + + +class RopeACL(jt.Function): + + def __init__(self): + super(RopeACL, self).__init__() + + def execute(self, xq, xk, freqs_cis, freq_cos, freq_sin): + attr_code = f""" + op.jt_name = "RotaryPosEmb"; + """ + if freqs_cis is not None: + freq_cos = freqs_cis[..., 0] + freq_sin = freqs_cis[..., 1] + else: + assert freq_cos is not None and freq_sin is not 
None + inputs = [xq, xk, freq_cos, freq_sin] + results = rope_cmd("RotaryPosEmb", + inputs, + output_dtypes=[ + xq.dtype, + ], + output_shapes=[ + xq.shape, + ], + attr_code=attr_code) + results[0].sync() + return inputs[0], inputs[1] + + def grad(self, grad_output): + return grad_output diff --git a/python/jittor/extern/acl/aclops/rope_op_acl.cc b/python/jittor/extern/acl/aclops/rope_op_acl.cc new file mode 100644 index 00000000..9be85716 --- /dev/null +++ b/python/jittor/extern/acl/aclops/rope_op_acl.cc @@ -0,0 +1,57 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "rope_op_acl.h" + +namespace jittor +{ + RotaryPosEmbOpRunner::RotaryPosEmbOpRunner() : BaseOpRunner("RotaryPosEmb") + { + } + + void RotaryPosEmbOpRunner::executeOp(std::unordered_map::iterator &it) + { + ret = aclnnApplyRotaryPosEmbGetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], (int64_t)1, &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnApplyRotaryPosEmb(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnApplyRotaryPosEmb failed. 
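# --- Illustrative reference (not part of the patch) ---------------------------
# The rotary-embedding identity RopeACL delegates to aclnnApplyRotaryPosEmb.
# Note that RopeACL mutates xq/xk in place and returns them. The sketch uses
# the common "rotate-half" pairing; the exact element pairing of the ACL kernel
# should be checked against the CANN documentation.
import numpy as np

def rope_ref(x, cos, sin):
    # x: (..., head_dim) with an even head_dim; cos/sin broadcastable to x
    x1, x2 = np.split(x, 2, axis=-1)
    rotated = np.concatenate((-x2, x1), axis=-1)
    return x * cos + rotated * sin
# -----------------------------------------------------------------------------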
ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/rope_op_acl.h b/python/jittor/extern/acl/aclops/rope_op_acl.h new file mode 100644 index 00000000..0f1b2996 --- /dev/null +++ b/python/jittor/extern/acl/aclops/rope_op_acl.h @@ -0,0 +1,17 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class RotaryPosEmbOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + RotaryPosEmbOpRunner(); + }; + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/setitem_op.py b/python/jittor/extern/acl/aclops/setitem_op.py new file mode 100644 index 00000000..58eb0f62 --- /dev/null +++ b/python/jittor/extern/acl/aclops/setitem_op.py @@ -0,0 +1,356 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def setitem_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + + +def setitem_forward(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None, + extra_data: dict = {}): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + // aclop + {name}OpRunner op; + {input_code} + op.add(out0, false); + {attr_code} + op.run();""", + data=extra_data) + + +def caculate_shape(tensors): + if isinstance(tensors, jt.Var): + # tensors = tensors[0] + return tensors.shape + elif isinstance(tensors, (int, float)): + return [] + elif isinstance(tensors, (list, tuple)): + # return [caculate_shape(tensor) for tensor in tensors] + sub_shape = caculate_shape(tensors[0]) + return [len(tensors)] + sub_shape + else: + assert False, f"not implemented 
for {type(tensors)}" + + +def can_broadcast_and_shape(shape1, shape2): + """ + 检查两个张量是否可以广播,并返回广播后的形状。 + + 参数: + - shape1: 第一个张量的形状(tuple 或 list) + - shape2: 第二个张量的形状(tuple 或 list) + + 返回: + - can_broadcast: 布尔值,表示是否可以广播 + - broadcast_shape: 如果可以广播,返回广播后的形状;否则返回 None + """ + # 将形状转换为元组,以防输入是列表 + shape1 = tuple(shape1) + shape2 = tuple(shape2) + + # 使两个形状的长度一致,通过在前面补1 + len1, len2 = len(shape1), len(shape2) + if len1 < len2: + shape1 = (1, ) * (len2 - len1) + shape1 + elif len2 < len1: + shape2 = (1, ) * (len1 - len2) + shape2 + + broadcast_shape = [] + + # 从最后一维开始检查每一维度 + for dim1, dim2 in zip(shape1, shape2): + if dim1 == dim2: + broadcast_shape.append(dim1) + elif dim1 == 1: + broadcast_shape.append(dim2) + elif dim2 == 1: + broadcast_shape.append(dim1) + else: + # 如果在某一维度上不兼容,则不能广播 + return False, None + + return True, tuple(broadcast_shape) + + +class SetItemACL(jt.Function): + + def __init__(self): + self.type_ = 'notype' + self.value_var = True + + def stride(self, x, dim): + stride = 1 + for i in range(dim + 1, len(x.shape)): + stride *= x.shape[i] + return stride + + def execute(self, x, slices, value): + self.x_shape = x.shape + self.input_slice = slices + if not isinstance(value, jt.Var): + self.value_var = False + if isinstance(slices, jt.Var): + if slices.dtype == "bool": + slices_len = slices.sum().item() + if slices_len == 0: + return x + if isinstance(value, int) or isinstance(value, float): + value = jt.full((slices_len, ), value, dtype=x.dtype) + assert slices.shape == x.shape, "setitem shape not match" + assert len(value.shape) == 1, "value shape must be 1D" + assert value.shape[ + 0] == slices_len, "value shape length must be equal to slices sum" + self.type_ = 'mask' + self.value_shape = value.shape + inputs = [slices, value] + outputs = [x.clone()] + attr_code = f""" + op.jt_name = "inplacemaskedscatter"; + """ + result = setitem_cmd("InplaceMaskedScatter", + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] + return result + + # assert isinstance(value,jt.Var), "value must be jt.Var" + # self.value_shape = value.shape + if not isinstance(slices, tuple): + slices = (slices, ) + slices = list(slices) + for i, s in enumerate(slices): + if isinstance(s, int) and s < 0: + slices[i] = x.shape[i] + s + slices = tuple(slices) + slices_list = list(slices) + #check slices contains slice type + contains_slice = False + for s in slices: + if not isinstance(s, jt.Var) and (isinstance(s, slice) + or s == Ellipsis): + contains_slice = True + break + if not contains_slice: + indices = [] + value_shape = [] + slices_len = len(slices) + boardcast_shape = caculate_shape(slices_list[0]) + for ii in range(1, len(slices)): + dd, boardcast_shape = can_broadcast_and_shape( + boardcast_shape, caculate_shape(slices_list[ii])) + assert dd is True, "can not broadcast" + value_shape = boardcast_shape + value_shape += x.shape[slices_len:] + if value_shape == []: + value_shape = [1] + if isinstance(value, int) or isinstance(value, float): + value = jt.full(value_shape, value) + self.value_shape = value_shape + for ii in slices: + indices.append(jt.Var(ii).int32()) + if isinstance(slices[0], + jt.Var) or isinstance(slices[0], int) or isinstance( + slices[0], list) or isinstance(slices[0], tuple): + self.indices = indices + self.type_ = 'index' + attr_code = f""" + op.jt_name = "indexputimpl"; + """ + inputs = [value] + indices + outputs = [x.clone()] + result = setitem_cmd("IndexPutImpl", + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] + # result.sync() + return result + 
assert "not support" + assert contains_slice, "slice type error" + x_dim = len(x.shape) + slices = list(slices) + for s in slices: + if not isinstance(s, jt.Var) and s == Ellipsis: + slices = slices[:slices.index(s)] + [ + slice(None, None, None) + ] * (x_dim - len(slices) + 1) + slices[slices.index(s) + 1:] + break + slices = tuple(slices) + self.input_slice = slices + if len(slices) < x_dim: + slices += (slice(None, None, None), ) * (x_dim - len(slices)) + sizes = [] + #适配华为奇怪的要求,最后一个维度的step必须是1 + expand_dim = False + if isinstance(slices[-1], slice): + if slices[-1].step is not None and slices[-1].step != 1: + slices = slices + (slice(None, None, None), ) + expand_dim = True + + elif isinstance(slices[-1], int): + #注意最后一个维度是数字 + slices = slices + (slice(None, None, None), ) + expand_dim = True + # value = value.unsqueeze(-1) + else: + assert False, "not supported" + x_shape = list(x.shape) + if expand_dim: + x_shape.append(1) + x = x.unsqueeze(-1) + value = value.unsqueeze(-1) + + squeeze_dims = [] + if isinstance(value, jt.Var): + for dim, s in enumerate(slices): + if isinstance(s, int): + s = slice(s, s + 1, 1) + squeeze_dims.append(dim) + + for dim in squeeze_dims: + value = value.unsqueeze(dim) + + extra_data = {} + if len(slices): + extra_data["a"] = len(slices) + for dim, s in enumerate(slices): + if isinstance(s, int): + s = slice(s, s + 1, 1) + if isinstance(s, jt.Var): + assert False, "jt.Var not supported" + start, stop, step = s.indices(x_shape[dim]) + size = (stop - start - 1) // step + 1 + sizes.append(size) + extra_data[str(dim * 3)] = start + extra_data[str(dim * 3 + 1)] = stop + extra_data[str(dim * 3 + 2)] = step + else: + extra_data["a"] = -1 + sizes = [1] + steps = [1] + if isinstance(value, int) or isinstance(value, float): + value = jt.full(sizes, value) + self.type_ = 'slicev2' + attr_code = """ + op.jt_name = "stridedsliceassignv2"; + StrideAttr *attr = new StrideAttr(); + int slice_dim = data["a"]; + + if(slice_dim == -1) { + attr->begins = {}; + attr->ends = {}; + attr->steps = {1}; + attr->axes = {}; + } else { + vector begins; + vector ends; + vector steps; + vector dims; + for(int dim = 0; dim < slice_dim; dim++) { + dims.push_back(dim); + begins.push_back(data[std::to_string(dim*3)]); + ends.push_back(data[std::to_string(dim*3+1)]); + steps.push_back(data[std::to_string(dim*3+2)]); + } + attr->begins = begins; + attr->ends = ends; + attr->steps = steps; + attr->axes = dims; + } + op.op_attr.reset(attr); + """ + self.value_shape = value.shape + inputs = [value] + outputs = [x.clone()] + result = setitem_forward("StridedSliceAssignV2", + inputs=inputs, + outputs=outputs, + attr_code=attr_code, + extra_data=extra_data)[0] + if expand_dim: + result = result.squeeze(-1) + # result.sync() + return result + + def grad(self, grad_output): + value_grad = None + if self.value_var: + value_grad = grad_output[self.input_slice] + grad_output[self.input_slice] = jt.zeros(self.value_shape) + return grad_output, None, value_grad diff --git a/python/jittor/extern/acl/aclops/setitem_op_acl.cc b/python/jittor/extern/acl/aclops/setitem_op_acl.cc new file mode 100644 index 00000000..1ed0e284 --- /dev/null +++ b/python/jittor/extern/acl/aclops/setitem_op_acl.cc @@ -0,0 +1,84 @@ +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include 
"ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "setitem_op_acl.h" + +namespace jittor +{ + InplaceMaskedScatterOpRunner::InplaceMaskedScatterOpRunner() : BaseOpRunner("InplaceMaskedScatter") + { + } + + void InplaceMaskedScatterOpRunner::executeOp(std::unordered_map::iterator &it) + { + ret = aclnnInplaceMaskedScatterGetWorkspaceSize(outputTensors[0], inputTensors[0], inputTensors[1], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnInplaceMaskedScatter(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnInplaceMaskedScatter failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + return; + } + + IndexPutImplOpRunner::IndexPutImplOpRunner() : BaseOpRunner("IndexPutImpl") + { + } + + void IndexPutImplOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto input_num = in_.size(); + std::vector indexTensorList = {}; + for (int i = 1; i < input_num; i++) + { + indexTensorList.push_back(inputTensors[i]); + } + auto indexTensorListInput = aclCreateTensorList(&indexTensorList[0], input_num - 1); + ret = aclnnIndexPutImplGetWorkspaceSize(outputTensors[0], indexTensorListInput, inputTensors[0], false, true, &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnIndexPutImpl(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnIndexPutImpl failed. 
ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/setitem_op_acl.h b/python/jittor/extern/acl/aclops/setitem_op_acl.h new file mode 100644 index 00000000..ddd73902 --- /dev/null +++ b/python/jittor/extern/acl/aclops/setitem_op_acl.h @@ -0,0 +1,26 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class InplaceMaskedScatterOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + InplaceMaskedScatterOpRunner(); + }; + + class IndexPutImplOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + IndexPutImplOpRunner(); + }; +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/sigmoid_op.py b/python/jittor/extern/acl/aclops/sigmoid_op.py new file mode 100644 index 00000000..ed3f1240 --- /dev/null +++ b/python/jittor/extern/acl/aclops/sigmoid_op.py @@ -0,0 +1,85 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def sigmoid_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + + +class SigmoidACL(jt.Function): + + def __init__(self): + super(SigmoidACL, self).__init__() + + def execute(self, x): + x = x.float32() + inputs = [x] + outputs = [jt.empty(x.shape, x.dtype)] + attr_code = f""" + op.jt_name = "sigmoid"; + """ + result = sigmoid_cmd("Sigmoid", + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] + self.output = result + return result + + def grad(self, grad_output): + attr_code = f""" + op.jt_name = "sigmoidbackward"; + """ + inputs = [grad_output, self.output] + outputs = [jt.empty(grad_output.shape, grad_output.dtype)] + grad_input = sigmoid_cmd("SigmoidBackward", + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] + return grad_input diff --git a/python/jittor/extern/acl/aclops/sigmoid_op_acl.cc b/python/jittor/extern/acl/aclops/sigmoid_op_acl.cc new file mode 100644 index 00000000..1014aecc --- /dev/null +++ b/python/jittor/extern/acl/aclops/sigmoid_op_acl.cc @@ -0,0 +1,80 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include 
"ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "sigmoid_op_acl.h" + +namespace jittor +{ + SigmoidOpRunner::SigmoidOpRunner() : BaseOpRunner("Sigmoid") + { + } + + void SigmoidOpRunner::executeOp(std::unordered_map::iterator &it) + { + ret = aclnnSigmoidGetWorkspaceSize(inputTensors[0], outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnSigmoid(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSigmoid failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + return; + } + + SigmoidBackwardOpRunner::SigmoidBackwardOpRunner() : BaseOpRunner("SigmoidBackward") + { + } + + void SigmoidBackwardOpRunner::executeOp(std::unordered_map::iterator &it) + { + ret = aclnnSigmoidBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnSigmoidBackward(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSigmoidBackward failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/sigmoid_op_acl.h b/python/jittor/extern/acl/aclops/sigmoid_op_acl.h new file mode 100644 index 00000000..b175cd01 --- /dev/null +++ b/python/jittor/extern/acl/aclops/sigmoid_op_acl.h @@ -0,0 +1,27 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class SigmoidOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + SigmoidOpRunner(); + }; + + class SigmoidBackwardOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + SigmoidBackwardOpRunner(); + }; + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/silu_op.py b/python/jittor/extern/acl/aclops/silu_op.py new file mode 100644 index 00000000..30613b6a --- /dev/null +++ b/python/jittor/extern/acl/aclops/silu_op.py @@ -0,0 +1,85 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def silu_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in 
range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + + +class SiLUACL(jt.Function): + + def __init__(self): + super(SiLUACL, self).__init__() + + def execute(self, x): + x = x.float32() + inputs = [x] + self.input = x + outputs = [jt.empty(x.shape, x.dtype)] + attr_code = f""" + op.jt_name = "silu"; + """ + result = silu_cmd("SiLU", + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] + return result + + def grad(self, grad_output): + attr_code = f""" + op.jt_name = "silubackward"; + """ + inputs = [grad_output, self.input] + outputs = [jt.empty(grad_output.shape, grad_output.dtype)] + grad_input = silu_cmd("SiLUBackward", + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] + return grad_input diff --git a/python/jittor/extern/acl/aclops/silu_op_acl.cc b/python/jittor/extern/acl/aclops/silu_op_acl.cc new file mode 100644 index 00000000..846da6da --- /dev/null +++ b/python/jittor/extern/acl/aclops/silu_op_acl.cc @@ -0,0 +1,80 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "silu_op_acl.h" + +namespace jittor +{ + SiLUOpRunner::SiLUOpRunner() : BaseOpRunner("SiLU") + { + } + + void SiLUOpRunner::executeOp(std::unordered_map::iterator &it) + { + ret = aclnnSiluGetWorkspaceSize(inputTensors[0], outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnSilu(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSilu failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + return; + } + + SiLUBackwardOpRunner::SiLUBackwardOpRunner() : BaseOpRunner("SiLUBackward") + { + } + + void SiLUBackwardOpRunner::executeOp(std::unordered_map::iterator &it) + { + ret = aclnnSiluBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnSiluBackward(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSiluBackward failed. 
ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/silu_op_acl.h b/python/jittor/extern/acl/aclops/silu_op_acl.h new file mode 100644 index 00000000..abc52810 --- /dev/null +++ b/python/jittor/extern/acl/aclops/silu_op_acl.h @@ -0,0 +1,27 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class SiLUOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + SiLUOpRunner(); + }; + + class SiLUBackwardOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + SiLUBackwardOpRunner(); + }; + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/softmax_op.py b/python/jittor/extern/acl/aclops/softmax_op.py new file mode 100644 index 00000000..85ae9c72 --- /dev/null +++ b/python/jittor/extern/acl/aclops/softmax_op.py @@ -0,0 +1,92 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def softmax_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + + +class SoftmaxACL(jt.Function): + + def __init__(self): + super(SoftmaxACL, self).__init__() + + def execute(self, x, dim): + x = x.float32() + inputs = [x] + outputs = [jt.empty(x.shape)] + self.dim = dim + attr_code = f""" + op.jt_name = "softmax"; + SoftmaxAttr *attr = new SoftmaxAttr(); + attr->dim = {dim}; + op.op_attr.reset(attr); + """ + result = softmax_cmd("Softmax", + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] + self.output = result + return result + + def grad(self, grad_output): + attr_code = f""" + op.jt_name = "softmax"; + SoftmaxAttr *attr = new SoftmaxAttr(); + attr->dim = {self.dim}; + op.op_attr.reset(attr); + """ + inputs = [grad_output, self.output] + outputs = [jt.empty(grad_output.shape)] + grad_input = softmax_cmd("SoftmaxBackward", + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] + return grad_input diff --git a/python/jittor/extern/acl/aclops/softmax_op_acl.cc b/python/jittor/extern/acl/aclops/softmax_op_acl.cc new file mode 100644 index 00000000..43ca950d --- /dev/null +++ b/python/jittor/extern/acl/aclops/softmax_op_acl.cc @@ -0,0 +1,82 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" 
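The SoftmaxACL function defined above wraps aclnnSoftmax for the forward pass and caches the forward result for aclnnSoftmaxBackward. A minimal usage sketch, assuming the ACL backend is active and that the module is importable under the hypothetical path shown:

import jittor as jt
# Hypothetical import path; SoftmaxACL is the class defined in softmax_op.py above.
from jittor.extern.acl.aclops.softmax_op import SoftmaxACL

x = jt.randn(2, 8)
y = SoftmaxACL()(x, 1)   # softmax along dim=1 via the ACL runner
print(y.sum(1))          # each row should sum to ~1 in float32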
+#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "softmax_op_acl.h" + +namespace jittor +{ + SoftmaxOpRunner::SoftmaxOpRunner() : BaseOpRunner("Softmax") + { + } + + void SoftmaxOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + ret = aclnnSoftmaxGetWorkspaceSize(inputTensors[0], aclDataType(attr->dim), outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnSoftmax(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSoftmax failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + return; + } + + SoftmaxBackwardOpRunner::SoftmaxBackwardOpRunner() : BaseOpRunner("SoftmaxBackward") + { + } + + void SoftmaxBackwardOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + ret = aclnnSoftmaxBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], attr->dim, outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnSoftmaxBackward(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSoftmaxBackward failed. 
ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/softmax_op_acl.h b/python/jittor/extern/acl/aclops/softmax_op_acl.h new file mode 100644 index 00000000..11af9d36 --- /dev/null +++ b/python/jittor/extern/acl/aclops/softmax_op_acl.h @@ -0,0 +1,27 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class SoftmaxOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + SoftmaxOpRunner(); + }; + + class SoftmaxBackwardOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + SoftmaxBackwardOpRunner(); + }; + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/stack_op.py b/python/jittor/extern/acl/aclops/stack_op.py new file mode 100644 index 00000000..c9ba50b3 --- /dev/null +++ b/python/jittor/extern/acl/aclops/stack_op.py @@ -0,0 +1,115 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def stack_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + + +class StackACL(jt.Function): + + def __init__(self): + super(StackACL, self).__init__() + + def execute(self, input_tensors, dim): + if type(input_tensors) is tuple: + input_tensors = list(input_tensors) + assert type(input_tensors) is list + assert -1 * len(input_tensors) - 1 <= dim and dim <= len(input_tensors) + for i in range(len(input_tensors)): + if input_tensors[i].dtype != input_tensors[0].dtype: + raise ValueError("All input tensors must have the same dtype") + if input_tensors[i].shape != input_tensors[0].shape: + raise ValueError("All input tensors must have the same shape") + self.input = input_tensors + input_shape = list(input_tensors[0].shape) + output_shape = input_shape[:dim] + [len(input_tensors) + ] + input_shape[dim:] + attr_code = f""" + op.jt_name = "stack"; + ConcatAttr *attr = new ConcatAttr(); + attr->tensorNum = {len(input_tensors)}; + attr->dim = {dim}; + op.op_attr.reset(attr); + """ + self.attr_code = attr_code + result = stack_cmd("Stack", + input_tensors, + output_dtypes=[input_tensors[0].dtype], + output_shapes=[output_shape], + attr_code=self.attr_code)[0] + return result + + def grad(self, grad_output): + grad_inputs = self.split_grad(grad_output, self.input, self.dim) + return grad_inputs + + def 
split_grad(self, grad_output, input_tensors, axis): + offset = [] + shapeVec = [] + dtypeVec = [] + for tensor in input_tensors: + offset.append(tensor.shape[axis]) + dtypeVec.append(tensor.dtype) + shapeVec.append(tensor.shape) + + attr_code = f""" + op.jt_name = "splitwithsize"; + auto *attr = new SplitWithSizeAttr(); + attr->splitSize = {{ {", ".join(map(str, offset))} }}; + attr->dim = {axis}; + op.op_attr.reset(attr); + """ + + result = stack_cmd("SplitWithSize", [grad_output], + output_dtypes=dtypeVec, + output_shapes=shapeVec, + attr_code=attr_code) + return result diff --git a/python/jittor/extern/acl/aclops/stack_op_acl.cc b/python/jittor/extern/acl/aclops/stack_op_acl.cc new file mode 100644 index 00000000..369582b0 --- /dev/null +++ b/python/jittor/extern/acl/aclops/stack_op_acl.cc @@ -0,0 +1,65 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "stack_op_acl.h" + +namespace jittor +{ + StackOpRunner::StackOpRunner() : BaseOpRunner("Stack") + { + } + + void StackOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto input_num = in_.size(); + std::vector stackTensorList = {}; + for (int i = 0; i < input_num; i++) + { + stackTensorList.push_back(inputTensors[i]); + } + auto stackTensorListInput = aclCreateTensorList(&stackTensorList[0], input_num); + auto attr = dynamic_cast(op_attr.get()); + ret = aclnnStackGetWorkspaceSize(stackTensorListInput, attr->dim, outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnStack(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnStack failed. 
ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/stack_op_acl.h b/python/jittor/extern/acl/aclops/stack_op_acl.h new file mode 100644 index 00000000..4b7df980 --- /dev/null +++ b/python/jittor/extern/acl/aclops/stack_op_acl.h @@ -0,0 +1,17 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class StackOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + StackOpRunner(); + }; + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/ternary_op_acl.cc b/python/jittor/extern/acl/aclops/ternary_op_acl.cc new file mode 100644 index 00000000..73f8fc4d --- /dev/null +++ b/python/jittor/extern/acl/aclops/ternary_op_acl.cc @@ -0,0 +1,54 @@ +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "ternary_op_acl.h" + +namespace jittor +{ + TernaryOpRunner::TernaryOpRunner() : BaseOpRunner("ternary") + { + use_nchw = false; + } + + void TernaryOpRunner::executeOp(std::unordered_map::iterator &it) + { + ret = aclnnSWhereGetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnSWhere(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed. 
ERROR: %d\n", name.c_str(), ret); return); + // syncRun(); + return; + } +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/ternary_op_acl.h b/python/jittor/extern/acl/aclops/ternary_op_acl.h new file mode 100644 index 00000000..2402c039 --- /dev/null +++ b/python/jittor/extern/acl/aclops/ternary_op_acl.h @@ -0,0 +1,14 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + struct TernaryOpRunner : public BaseOpRunner + { + TernaryOpRunner(); + + protected: + void executeOp(std::unordered_map::iterator &it) override; + }; +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/transpose_op.py b/python/jittor/extern/acl/aclops/transpose_op.py new file mode 100644 index 00000000..aa1d7e55 --- /dev/null +++ b/python/jittor/extern/acl/aclops/transpose_op.py @@ -0,0 +1,101 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def transpose_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + + +class TransPoseACL(jt.Function): + + def __init__(self): + super(TransPoseACL, self).__init__() + + def execute(self, x, *dim): + self.input = x + if len(dim) == 1 and isinstance(dim[0], Sequence): + dim = dim[0] + elif len(dim) == 2: + axes = list(range(x.ndim)) + a, b = dim + axes[a], axes[b] = axes[b], axes[a] + dim = axes + + attr_code = f""" + op.jt_name = "transpose"; + ReduceAttr *attr = new ReduceAttr(); + attr->axes = {{ {", ".join(map(str, dim))} }}; + op.op_attr.reset(attr); + """ + # calculate output shape + output_shape = [x.shape[i] for i in dim] + output = transpose_cmd("Transpose", [x], + output_dtypes=[x.dtype], + output_shapes=[jt.empty(output_shape).shape], + attr_code=attr_code)[0] + self.dim = dim + return output + + def grad(self, grad_output): + dim = list(range(grad_output.ndim)) + for i, p in enumerate(self.dim): + dim[p] = i + output_shape = [grad_output.shape[i] for i in dim] + attr_code = f""" + op.jt_name = "transpose"; + ReduceAttr *attr = new ReduceAttr(); + attr->axes = {{ {", ".join(map(str, dim))} }}; + op.op_attr.reset(attr); + """ + output = transpose_cmd("Transpose", [grad_output], + output_dtypes=[grad_output.dtype], + output_shapes=[jt.empty(output_shape).shape], + attr_code=attr_code)[0] + return output diff --git a/python/jittor/extern/acl/aclops/transpose_op_acl.cc b/python/jittor/extern/acl/aclops/transpose_op_acl.cc new file mode 100644 index 00000000..721dee20 --- 
/dev/null +++ b/python/jittor/extern/acl/aclops/transpose_op_acl.cc @@ -0,0 +1,66 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "transpose_op_acl.h" + +namespace jittor +{ + TransposeOpRunner::TransposeOpRunner() : BaseOpRunner("Transpose") + { + } + + void TransposeOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + + aclIntArray *dim = nullptr; + + dim = aclCreateIntArray(attr->axes.data(), attr->axes.size()); + bool keepdims = attr->keepdims; + + ret = aclnnPermuteGetWorkspaceSize(inputTensors[0], dim, outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnPermute(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnPermute failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + aclDestroyIntArray(dim); + + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/transpose_op_acl.h b/python/jittor/extern/acl/aclops/transpose_op_acl.h new file mode 100644 index 00000000..737fffd8 --- /dev/null +++ b/python/jittor/extern/acl/aclops/transpose_op_acl.h @@ -0,0 +1,17 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class TransposeOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + TransposeOpRunner(); + }; + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/triu_op.py b/python/jittor/extern/acl/aclops/triu_op.py new file mode 100644 index 00000000..f5309c9d --- /dev/null +++ b/python/jittor/extern/acl/aclops/triu_op.py @@ -0,0 +1,74 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + +def triu_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + 
{input_code} + {output_code} + {attr_code} + op.run();""") + +class TriuACL(jt.Function): + + def __init__(self): + super(TriuACL, self).__init__() + + def execute(self, input, diagonal): + attr_code = f""" + op.jt_name = "triu"; + TriuAttr *attr = new TriuAttr(); + attr->diagonal = {diagonal}; + op.op_attr.reset(attr); + """ + + result = triu_cmd("Triu", [input], + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr_code=attr_code)[0] + return result + + def grad(self, grad_output): + return grad_output diff --git a/python/jittor/extern/acl/aclops/triu_op_acl.cc b/python/jittor/extern/acl/aclops/triu_op_acl.cc new file mode 100644 index 00000000..8da66090 --- /dev/null +++ b/python/jittor/extern/acl/aclops/triu_op_acl.cc @@ -0,0 +1,58 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "triu_op_acl.h" + +namespace jittor +{ + TriuOpRunner::TriuOpRunner() : BaseOpRunner("Triu") + { + } + + void TriuOpRunner::executeOp(std::unordered_map::iterator &it) + { + auto attr = dynamic_cast(op_attr.get()); + ret = aclnnTriuGetWorkspaceSize(inputTensors[0], aclDataType(attr->diagonal), outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnTriu(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnTriu failed. 
ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/triu_op_acl.h b/python/jittor/extern/acl/aclops/triu_op_acl.h new file mode 100644 index 00000000..8ffb3502 --- /dev/null +++ b/python/jittor/extern/acl/aclops/triu_op_acl.h @@ -0,0 +1,16 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class TriuOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + public: + TriuOpRunner(); + }; + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/unary_op_acl.cc b/python/jittor/extern/acl/aclops/unary_op_acl.cc new file mode 100644 index 00000000..d1172fce --- /dev/null +++ b/python/jittor/extern/acl/aclops/unary_op_acl.cc @@ -0,0 +1,59 @@ +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" + +#include "unary_op_acl.h" + +namespace jittor +{ + UnaryOpRunner::UnaryOpRunner() : BaseOpRunner("unary") + { + use_nchw = false; + is_group_op = true; + } + + void UnaryOpRunner::executeOp(std::unordered_map::iterator &it) + { + if (name == "Cast") + ret = it->second.getWorkspaceSizeFuncCast(inputTensors[0], get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor); + else + ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = it->second.executeFunc(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed. 
ERROR: %d\n", name.c_str(), ret); return); + // syncRun(); + return; + } +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/unary_op_acl.h b/python/jittor/extern/acl/aclops/unary_op_acl.h new file mode 100644 index 00000000..b4a0ab9c --- /dev/null +++ b/python/jittor/extern/acl/aclops/unary_op_acl.h @@ -0,0 +1,14 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + struct UnaryOpRunner : public BaseOpRunner + { + UnaryOpRunner(); + + protected: + void executeOp(std::unordered_map::iterator &it) override; + }; +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/utils.cc b/python/jittor/extern/acl/aclops/utils.cc new file mode 100644 index 00000000..1aac88db --- /dev/null +++ b/python/jittor/extern/acl/aclops/utils.cc @@ -0,0 +1,152 @@ +#include +#include +#include +#include +#include +#include +#include "utils.h" + +namespace jittor +{ + aclDataType get_dtype(NanoString s) + { + switch (s.data) + { + case 22667: + return ACL_FLOAT; + case 14474: + return ACL_FLOAT16; + case 27781: + return ACL_INT64; + case 19588: + return ACL_INT32; + case 3202: + return ACL_INT8; + case 11395: + return ACL_INT16; + case 3206: + return ACL_UINT8; + case 11399: + return ACL_UINT16; + case 19592: + return ACL_UINT32; + case 3713: + return ACL_BOOL; + default: + LOGf << "Not supported dtype: " << s; + return ACL_FLOAT; // 默认返回 ACL_FLOAT + } + } + + std::unordered_map op_idx_map = + { + {"Add", 1}, + {"Sub", 2}, + {"Expand", 3}, + {"Cast", 4}, + {"Unary", 5}, + {"Binary", 6}, + {"BatchMatMul", 7}, + {"MatMul", 8}, + {"ReduceSum", 9}, + {"ReduceMean", 10}, + {"ReduceMax", 11}, + {"ReduceMin", 12}, + {"RandomUniform", 13}, + {"RandomNormal", 14}, + {"Nonzero", 15}, + {"Select", 16}, + {"Where", 17}, + {"Triu", 18}, + {"Transpose", 19}, + {"Conv2d", 20}, + {"Conv2dBackward", 21}, + {"Maxpool", 22}, + {"MaxpoolBackward", 23}, + {"Avgpool", 24}, + {"AvgpoolBackward", 25}, + {"Flip", 26}, + {"Concat", 27}, + {"Gather", 28}, + {"Cumsum", 29}, + {"Scatter", 30}, + {"Floor", 31}, + {"Index", 32}, + {"SliceV2", 33}, + {"IndexPutImpl", 34}, + {"IndexPutImplAccumulate", 35}, + {"StridedSliceAssignV2", 36}, + {"Range", 37}, + {"LeakyReLU", 38}, + {"LeakyReLUBackward", 39}, + {"Dropout", 40}, + {"DropoutBackward", 41}, + {"SiLU", 42}, + {"SiLUBackward", 43}, + {"Sigmoid", 44}, + {"SigmoidBackward", 45}, + {"Embedding", 46}, + {"EmbeddingBackward", 47}, + {"InplaceMaskedScatter", 48}, + {"MaskedSelect", 49}, + {"SplitWithSize", 50}, + {"FlashAttention", 51}, + {"FlashAttentionBackward", 52}, + {"Softmax", 53}, + {"SoftmaxBackward", 54}, + {"BatchNorm", 55}, + {"BatchNormBackward", 56}, + {"LayerNorm", 57}, + {"RotaryPosEmb", 58}, + {"Stack", 59}, + {"NanToNum", 60}, + }; + + int CreateAclTensor(const std::vector &shape, void *deviceAddr, int64_t size, + aclDataType dataType, aclTensor **tensor, bool use_nchw) + { + // 计算连续tensor的strides + std::vector strides(shape.size(), 1); + for (int64_t i = shape.size() - 2; i >= 0; i--) + { + strides[i] = shape[i + 1] * strides[i + 1]; + } + if (shape.size() == 0) + strides = {}; + // 调用aclCreateTensor接口创建aclTensor + if (use_nchw) + *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_NCHW, + shape.data(), shape.size(), deviceAddr); + else + *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND, + shape.data(), shape.size(), deviceAddr); + return 0; + } + + int CreateFakeTransAclTensor(std::vector 
&shape, void *deviceAddr, int64_t size, + aclDataType dataType, aclTensor **tensor, bool use_nchw) + { + // 计算连续tensor的strides + std::vector strides(shape.size(), 1); + for (int64_t i = shape.size() - 2; i >= 0; i--) + { + strides[i] = shape[i + 1] * strides[i + 1]; + } + if (shape.size() == 0) + strides = {}; + int n = shape.size(); + if (n > 1) + { + std::swap(shape[n - 1], shape[n - 2]); + std::swap(strides[n - 1], strides[n - 2]); + } + // 调用aclCreateTensor接口创建aclTensor + if (use_nchw) + *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_NCHW, + shape.data(), shape.size(), deviceAddr); + else + *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND, + shape.data(), shape.size(), deviceAddr); + return 0; + } +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/utils.h b/python/jittor/extern/acl/aclops/utils.h new file mode 100644 index 00000000..de2b7bc7 --- /dev/null +++ b/python/jittor/extern/acl/aclops/utils.h @@ -0,0 +1,20 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include "misc/nano_string.h" + +namespace jittor +{ + aclDataType get_dtype(NanoString s); + + extern std::unordered_map op_idx_map; + int CreateAclTensor(const std::vector &shape, void *deviceAddr, int64_t size, + aclDataType dataType, aclTensor **tensor, bool use_nchw = false); + + int CreateFakeTransAclTensor(std::vector &shape, void *deviceAddr, int64_t size, + aclDataType dataType, aclTensor **tensor, bool use_nchw = false); +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/where_op.py b/python/jittor/extern/acl/aclops/where_op.py new file mode 100644 index 00000000..f0417bb9 --- /dev/null +++ b/python/jittor/extern/acl/aclops/where_op.py @@ -0,0 +1,129 @@ +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math +import numpy as np + +from typing import Union +from collections.abc import Sequence, Iterable + + +def where_cmd(name: str, + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): + attr_header = "\nnamespace jittor{" + attr_header + "}\n" + + cuda_header = ''' + #include "acl/aclops/aclops.h" + ''' + outputs_ = [] + if outputs is not None: + outputs_ = outputs + else: + assert output_dtypes is not None + assert output_shapes is not None + assert len(output_dtypes) == len(output_shapes) + for i in range(len(output_shapes)): + outputs_.append(jt.empty(output_shapes[i], output_dtypes[i])) + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(outputs_)): + output_code += f"op.add(out{i}, false);\n" + return jt.code(outputs=outputs_, + inputs=inputs, + cuda_header=attr_header + cuda_header, + cuda_src=f""" + + // aclop + {name}OpRunner op; + {input_code} + {output_code} + {attr_code} + op.run();""") + + +class NonzeroACL(jt.Function): + + def __init__(self): + super(NonzeroACL, self).__init__() + + def execute(self, x): + attr_code = f""" + op.jt_name = "nonzero"; + """ + nonzero_cnt = (x != 0.0).sum().item() + + result = where_cmd("Nonzero", [x], + output_dtypes=['int64'], + output_shapes=[(nonzero_cnt, x.ndim)], + attr_code=attr_code)[0] + + return result + + def grad(self, grad_output): + return grad_output + + +class 
WhereACL(jt.Function): + + def __init__(self): + super(WhereACL, self).__init__() + + def execute(self, condition, x=None, y=None): + # case 1 (unary) + if y is None: + self.unary = True + + # In this case, `condition` is the input, while `x` is dtype + result = NonzeroACL()(condition).t() + result = [result[i] for i in range(result.size(0))] + return result + # The return value should be a tuple, but even we set to tuple here, it will be convert to a list in `Function.__call__`. + + # case 2 (cond ? x : y) + else: + self.condition = condition + + if x.dtype != y.dtype: + if x.dtype == jt.float32: + y = y.float32() + elif y.dtype == jt.float32: + x = x.float32() + else: + x = x.to(y.dtype) + + self.x = x + self.y = y + + result = where_cmd("Where", [condition, x, y], + output_dtypes=[x.dtype], + output_shapes=[x.shape], + attr_code="op.jt_name=\"where\";")[0] + return result + + def grad(self, grad_output): + if hasattr(self, 'unary') and self.unary: + return grad_output + else: + tmp = jt.zeros(grad_output.shape, dtype=grad_output.dtype) + grad_x = where_cmd("Where", [self.condition, grad_output, tmp], + output_dtypes=[self.x.dtype], + output_shapes=[self.x.shape], + attr_code="op.jt_name=\"where\";")[0] + + grad_y = where_cmd("Where", [self.condition, tmp, grad_output], + output_dtypes=[self.y.dtype], + output_shapes=[self.y.shape], + attr_code="op.jt_name=\"where\";")[0] + return grad_output, grad_x, grad_y diff --git a/python/jittor/extern/acl/aclops/where_op_acl.cc b/python/jittor/extern/acl/aclops/where_op_acl.cc new file mode 100644 index 00000000..a1d2ddc8 --- /dev/null +++ b/python/jittor/extern/acl/aclops/where_op_acl.cc @@ -0,0 +1,78 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" +#include "where_op_acl.h" + +namespace jittor +{ + WhereOpRunner::WhereOpRunner() : BaseOpRunner("Where") + { + } + + void WhereOpRunner::executeOp(std::unordered_map::iterator &it) + { + ret = aclnnSWhereGetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnSWhere(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSWhere failed. ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + return; + } + + NonzeroOpRunner::NonzeroOpRunner() : BaseOpRunner("Nonzero") + { + } + + void NonzeroOpRunner::executeOp(std::unordered_map::iterator &it) + { + ret = aclnnNonzeroGetWorkspaceSize(inputTensors[0], outputTensors[0], &workspaceSize, &executor); + + checkRet(ret); + + if (workspaceSize > 0) + { + mallocWorkSpace(workspaceSize); + } + + ret = aclnnNonzero(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnNonzero failed. 
ERROR: %d\n", name.c_str(), ret); return); + + syncRun(); + return; + } + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/where_op_acl.h b/python/jittor/extern/acl/aclops/where_op_acl.h new file mode 100644 index 00000000..d881f752 --- /dev/null +++ b/python/jittor/extern/acl/aclops/where_op_acl.h @@ -0,0 +1,26 @@ +#pragma once +#include "utils.h" +#include "base_op.h" + +namespace jittor +{ + class WhereOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + WhereOpRunner(); + }; + + class NonzeroOpRunner : public BaseOpRunner + { + + protected: + void executeOp(std::unordered_map::iterator &it) override; + + public: + NonzeroOpRunner(); + }; +} \ No newline at end of file diff --git a/python/jittor/extern/acl/hccl/inc/hccl_wrapper.h b/python/jittor/extern/acl/hccl/inc/hccl_wrapper.h new file mode 100644 index 00000000..51202e92 --- /dev/null +++ b/python/jittor/extern/acl/hccl/inc/hccl_wrapper.h @@ -0,0 +1,38 @@ +// *************************************************************** +// Copyright (c) 2025 Jittor. +// All Rights Reserved. +// Maintainers: +// Jiapeng Zhang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** + +#pragma once +#include "mpi_wrapper.h" + +#define ACLCHECK(ret) do {\ + if(ret != ACL_SUCCESS)\ + {\ + LOGe << "retcode: " << ret;\ + return;\ + }\ +} while(0)\ + +#define HCCLCHECK(ret) do {\ + if(ret != HCCL_SUCCESS)\ + {\ + LOGe << HcclGetErrorString(ret) << " retcode: " << ret;\ + return;\ + }\ +} while(0) + +#include + +namespace jittor { + + EXTERN_LIB HcclRootInfo root_info; + EXTERN_LIB HcclComm comm; + EXTERN_LIB uint32_t hccl_device_id; + +} // jittor diff --git a/python/jittor/extern/acl/hccl/ops/hccl_all_gather_op.cc b/python/jittor/extern/acl/hccl/ops/hccl_all_gather_op.cc new file mode 100644 index 00000000..d1d76879 --- /dev/null +++ b/python/jittor/extern/acl/hccl/ops/hccl_all_gather_op.cc @@ -0,0 +1,70 @@ +// *************************************************************** +// Copyright (c) 2025 Jittor. +// All Rights Reserved. +// Maintainers: +// Jiapeng Zhang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. 
+// *************************************************************** + +#include "var.h" +#include "hccl_all_gather_op.h" +#include "ops/op_register.h" +#include "utils/str_utils.h" +#include "hccl_wrapper.h" + +namespace jittor { + +#ifndef JIT + +static auto hccl_all_gather = + get_op_info("hccl_all_gather").get_constructor(); + +HcclAllGatherOp::HcclAllGatherOp(Var* x) : x(x) { + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_cuda, 1); + y = create_output(nullptr, x->dtype()); +} + +void HcclAllGatherOp::infer_shape() { + NanoVector yshape; + yshape.push_back(mpi_world_size * x->shape[0]); + for (int i=1; ishape.size(); i++) + yshape.push_back(x->shape[i]); + y->set_shape(yshape); +} + +VarPtr HcclAllGatherOp::grad(Var* out, Var* dout, Var* v, int v_index) { + LOGf << "not implemented"; + return nullptr; +} + +void HcclAllGatherOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); +} + +#else // JIT + +void HcclAllGatherOp::jit_run() { + LOGir << "HcclAllGatherOp::jit_run"; + @define(T_HCCL, + @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, HcclDataType::HCCL_DATA_TYPE_FP32) + @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, HcclDataType::HCCL_DATA_TYPE_INT32) + @if(@strcmp(@Tx,float64)==0, HcclDataType::HCCL_DATA_TYPE_FP64) + @if(@strcmp(@Tx,int64)==0, HcclDataType::HCCL_DATA_TYPE_INT64) + @if(@strcmp(@Tx,uint8)==0, HcclDataType::HCCL_DATA_TYPE_UINT8) + @if(@strcmp(@Tx,float16)==0, HcclDataType::HCCL_DATA_TYPE_FP16) + ) + auto* __restrict__ xp = x->ptr(); + auto* __restrict__ yp = y->ptr(); + ACLCHECK(aclrtSynchronizeDevice()); + ACLCHECK(aclrtSynchronizeStream(aclstream)); + HCCLCHECK(HcclAllGather(xp, yp, (uint64_t)x->num, @T_HCCL, comm, aclstream)); + ACLCHECK(aclrtSynchronizeDevice()); + ACLCHECK(aclrtSynchronizeStream(aclstream)); +} + +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/acl/hccl/ops/hccl_all_gather_op.h b/python/jittor/extern/acl/hccl/ops/hccl_all_gather_op.h new file mode 100644 index 00000000..e3b7caeb --- /dev/null +++ b/python/jittor/extern/acl/hccl/ops/hccl_all_gather_op.h @@ -0,0 +1,26 @@ +// *************************************************************** +// Copyright (c) 2025 Jittor. +// All Rights Reserved. +// Maintainers: +// Jiapeng Zhang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. 
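As HcclAllGatherOp::infer_shape above shows, the gathered output stacks the per-rank tensors along dim 0, so the result shape is the input shape with its first dimension multiplied by mpi_world_size. A small shape-rule sketch (pure Python, hypothetical sizes):

# Shape rule implemented by HcclAllGatherOp::infer_shape above:
#   out.shape = [world_size * x.shape[0]] + x.shape[1:]
def all_gather_shape(x_shape, world_size):
    return [world_size * x_shape[0]] + list(x_shape[1:])

assert all_gather_shape([4, 16, 16], 8) == [32, 16, 16]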
+// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct HcclAllGatherOp : Op { + Var* x, * y; + + HcclAllGatherOp(Var* x); + void infer_shape() override; + + const char* name() const override { return "hccl_all_gather"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/acl/hccl/ops/hccl_all_reduce_op.cc b/python/jittor/extern/acl/hccl/ops/hccl_all_reduce_op.cc new file mode 100644 index 00000000..8910c09a --- /dev/null +++ b/python/jittor/extern/acl/hccl/ops/hccl_all_reduce_op.cc @@ -0,0 +1,62 @@ +#include "var.h" +#include "hccl_all_reduce_op.h" +#include "ops/op_register.h" +#include "utils/str_utils.h" +#include "hccl_wrapper.h" + +namespace jittor { + +#ifndef JIT + +static auto hccl_all_reduce = + get_op_info("hccl_all_reduce").get_constructor(); + +HcclAllReduceOp::HcclAllReduceOp(Var* x, string reduce_op) : x(x), reduce_op(reduce_op) { + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_cuda, 1); + y = create_output(nullptr, x->dtype()); +} + +void HcclAllReduceOp::infer_shape() { + y->set_shape(x->shape); +} + +VarPtr HcclAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) { + return hccl_all_reduce(dout, reduce_op); +} + +void HcclAllReduceOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); + jk << "«Op:" << reduce_op; +} + +#else // JIT + +void HcclAllReduceOp::jit_run() { + //LOGir << "HcclAllReduceOp::jit_run"; + @define(T_HCCL, + @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, HcclDataType::HCCL_DATA_TYPE_FP32) + @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, HcclDataType::HCCL_DATA_TYPE_INT32) + @if(@strcmp(@Tx,float64)==0, HcclDataType::HCCL_DATA_TYPE_FP64) + @if(@strcmp(@Tx,int64)==0, HcclDataType::HCCL_DATA_TYPE_INT64) + @if(@strcmp(@Tx,uint8)==0, HcclDataType::HCCL_DATA_TYPE_UINT8) + @if(@strcmp(@Tx,float16)==0, HcclDataType::HCCL_DATA_TYPE_FP16) + ) + @define(REDUCE_OP, + @if(@strcmp(@Op,sum)==0, HcclReduceOp::HCCL_REDUCE_SUM) + @if(@strcmp(@Op,prod)==0, HcclReduceOp::HCCL_REDUCE_PROD) + @if(@strcmp(@Op,max)==0, HcclReduceOp::HCCL_REDUCE_MAX) + @if(@strcmp(@Op,min)==0, HcclReduceOp::HCCL_REDUCE_MIN) + ) + auto* __restrict__ xp = x->ptr(); + auto* __restrict__ yp = y->ptr(); + ACLCHECK(aclrtSynchronizeDevice()); + ACLCHECK(aclrtSynchronizeStream(aclstream)); + HCCLCHECK(HcclAllReduce(xp, yp, (uint64_t)x->num, @T_HCCL, @REDUCE_OP, comm, aclstream)); + ACLCHECK(aclrtSynchronizeDevice()); + ACLCHECK(aclrtSynchronizeStream(aclstream)); +} + +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/acl/hccl/ops/hccl_all_reduce_op.h b/python/jittor/extern/acl/hccl/ops/hccl_all_reduce_op.h new file mode 100644 index 00000000..2b706f21 --- /dev/null +++ b/python/jittor/extern/acl/hccl/ops/hccl_all_reduce_op.h @@ -0,0 +1,18 @@ +#pragma once +#include "op.h" + +namespace jittor { + +struct HcclAllReduceOp : Op { + Var* x, * y; + string reduce_op; + + HcclAllReduceOp(Var* x, string reduce_op="sum"); + void infer_shape() override; + + const char* name() const override { return "hccl_all_reduce"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + DECLARE_jit_run; +}; + +} // jittor diff --git a/python/jittor/extern/acl/hccl/ops/hccl_broadcast_op.cc b/python/jittor/extern/acl/hccl/ops/hccl_broadcast_op.cc new file mode 100644 index 00000000..bb1e3033 --- /dev/null +++ b/python/jittor/extern/acl/hccl/ops/hccl_broadcast_op.cc @@ -0,0 +1,63 @@ 
+#include "var.h" +#include "hccl_broadcast_op.h" +#include "ops/op_register.h" +#include "utils/str_utils.h" +#include "hccl_wrapper.h" +#include + +namespace jittor { + +#ifndef JIT + +static auto hccl_broadcast = + get_op_info("hccl_broadcast").get_constructor(); + +HcclBroadcastOp::HcclBroadcastOp(Var* x, int root) : x(x), root(root) { + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_cuda, 1); + y = create_output(nullptr, x->dtype()); +} + +void HcclBroadcastOp::infer_shape() { + y->set_shape(x->shape); +} + +VarPtr HcclBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) { + return hccl_broadcast(dout, root); +} + +void HcclBroadcastOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); + jk << "«Root:" << root; +} + +#else // JIT + +void HcclBroadcastOp::jit_run() { + //LOGir << "HcclBroadcastOp::jit_run"; + @define(T_HCCL, + @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, HcclDataType::HCCL_DATA_TYPE_FP32) + @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, HcclDataType::HCCL_DATA_TYPE_INT32) + @if(@strcmp(@Tx,float64)==0, HcclDataType::HCCL_DATA_TYPE_FP64) + @if(@strcmp(@Tx,int64)==0, HcclDataType::HCCL_DATA_TYPE_INT64) + @if(@strcmp(@Tx,uint8)==0, HcclDataType::HCCL_DATA_TYPE_UINT8) + @if(@strcmp(@Tx,float16)==0, HcclDataType::HCCL_DATA_TYPE_FP16) + ) + auto* __restrict__ xp = x->ptr(); + auto* __restrict__ yp = y->ptr(); + //LOGir << "HcclBroadcastOp::jit_run " << @Root << " " << hccl_device_id << " " << xp << " " << yp; + //ACLCHECK(aclrtSynchronizeStream(aclstream)); + ACLCHECK(aclrtSynchronizeDevice()); + ACLCHECK(aclrtSynchronizeStream(aclstream)); + HCCLCHECK(HcclBroadcast(@Root == hccl_device_id ? xp : yp, (uint64_t)x->num, @T_HCCL, @Root, comm, aclstream)); + if (@Root == hccl_device_id) { + ACLCHECK(aclrtMemcpy(yp, x->num * sizeof(Tx), xp, x->num * sizeof(Tx), ACL_MEMCPY_DEVICE_TO_DEVICE)); + ACLCHECK(aclrtSynchronizeDevice()); + } + ACLCHECK(aclrtSynchronizeDevice()); + ACLCHECK(aclrtSynchronizeStream(aclstream)); +} + +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/acl/hccl/ops/hccl_broadcast_op.h b/python/jittor/extern/acl/hccl/ops/hccl_broadcast_op.h new file mode 100644 index 00000000..17f4e16a --- /dev/null +++ b/python/jittor/extern/acl/hccl/ops/hccl_broadcast_op.h @@ -0,0 +1,18 @@ +#pragma once +#include "op.h" + +namespace jittor { + +struct HcclBroadcastOp : Op { + Var* x, * y; + int root; + + HcclBroadcastOp(Var* x, int root=0); + void infer_shape() override; + + const char* name() const override { return "hccl_broadcast"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + DECLARE_jit_run; +}; + +} // jittor diff --git a/python/jittor/extern/acl/hccl/ops/hccl_reduce_op.cc b/python/jittor/extern/acl/hccl/ops/hccl_reduce_op.cc new file mode 100644 index 00000000..4e952d6a --- /dev/null +++ b/python/jittor/extern/acl/hccl/ops/hccl_reduce_op.cc @@ -0,0 +1,63 @@ +#include "var.h" +#include "hccl_reduce_op.h" +#include "ops/op_register.h" +#include "utils/str_utils.h" +#include "hccl_wrapper.h" + +namespace jittor { + +#ifndef JIT + +HcclReduceOp::HcclReduceOp(Var* x, string reduce_op, int root) : x(x), reduce_op(reduce_op), root(root) { + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_cuda, 1); + y = create_output(nullptr, x->dtype()); +} + +void HcclReduceOp::infer_shape() { + y->set_shape(x->shape); +} + +VarPtr HcclReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) { + static auto hccl_broadcast = + get_op_info("hccl_broadcast").get_constructor(); + return hccl_broadcast(dout, root); +} + 
+void HcclReduceOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); + jk << "«Op:" << reduce_op; + jk << "«Root:" << root; +} + +#else // JIT + +void HcclReduceOp::jit_run() { + LOGir << "HcclReduceOp::jit_run"; + @define(T_HCCL, + @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, HcclDataType::HCCL_DATA_TYPE_FP32) + @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, HcclDataType::HCCL_DATA_TYPE_INT32) + @if(@strcmp(@Tx,float64)==0, HcclDataType::HCCL_DATA_TYPE_FP64) + @if(@strcmp(@Tx,int64)==0, HcclDataType::HCCL_DATA_TYPE_INT64) + @if(@strcmp(@Tx,uint8)==0, HcclDataType::HCCL_DATA_TYPE_UINT8) + @if(@strcmp(@Tx,float16)==0, HcclDataType::HCCL_DATA_TYPE_FP16) + ) + @define(REDUCE_OP, + @if(@strcmp(@Op,sum)==0, HcclReduceOp::HCCL_REDUCE_SUM) + @if(@strcmp(@Op,prod)==0, HcclReduceOp::HCCL_REDUCE_PROD) + @if(@strcmp(@Op,max)==0, HcclReduceOp::HCCL_REDUCE_MAX) + @if(@strcmp(@Op,min)==0, HcclReduceOp::HCCL_REDUCE_MIN) + ) + auto* __restrict__ xp = x->ptr(); + auto* __restrict__ yp = y->ptr(); + + ACLCHECK(aclrtSynchronizeDevice()); + ACLCHECK(aclrtSynchronizeStream(aclstream)); + HCCLCHECK(HcclReduce(xp, yp, (uint64_t)x->num, @T_HCCL, @REDUCE_OP, @Root, comm, aclstream)); + ACLCHECK(aclrtSynchronizeDevice()); + ACLCHECK(aclrtSynchronizeStream(aclstream)); +} + +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/acl/hccl/ops/hccl_reduce_op.h b/python/jittor/extern/acl/hccl/ops/hccl_reduce_op.h new file mode 100644 index 00000000..2ae276c9 --- /dev/null +++ b/python/jittor/extern/acl/hccl/ops/hccl_reduce_op.h @@ -0,0 +1,19 @@ +#pragma once +#include "op.h" + +namespace jittor { + +struct HcclReduceOp : Op { + Var* x, * y; + string reduce_op; + int root; + + HcclReduceOp(Var* x, string reduce_op="sum", int root=0); + void infer_shape() override; + + const char* name() const override { return "hccl_reduce"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + DECLARE_jit_run; +}; + +} // jittor diff --git a/python/jittor/extern/acl/hccl/src/hccl_wrapper.cc b/python/jittor/extern/acl/hccl/src/hccl_wrapper.cc new file mode 100644 index 00000000..3c0eb691 --- /dev/null +++ b/python/jittor/extern/acl/hccl/src/hccl_wrapper.cc @@ -0,0 +1,60 @@ +// *************************************************************** +// Copyright (c) 2025 Jittor. +// All Rights Reserved. +// Maintainers: +// Jiapeng Zhang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. 
+// *************************************************************** + +#include "hccl_wrapper.h" +#include "event_queue.h" +#include "acl_jittor.h" +#include + +namespace jittor { + +HcclRootInfo root_info; +HcclComm comm; +uint32_t hccl_device_id = 0; + +struct hccl_initer { + uint32_t device_count = 0; + hccl_initer() { + ACLCHECK(aclrtGetDeviceCount(&device_count)); + if (!device_count) return; + if (!inside_mpi) return; + hccl_device_id = mpi_local_rank; + if (mpi_local_rank >= device_count) { + LOGw << "mpi_local_rank(">>mpi_local_rank>>") is larger than device_count(" + >>device_count>>")"; + hccl_device_id = hccl_device_id % device_count; + } + LOGv << "HCCL init in device" << hccl_device_id << "local_rank" << mpi_local_rank; + //LOGir << aclstream; + //event_queue.run_sync([]() { + ACLCHECK(aclrtSetDevice(hccl_device_id)); + //}); + use_device_mpi = true; + LOGir << "HCCL init in device" << hccl_device_id << "local_rank" << mpi_local_rank; + if (mpi_world_rank == 0) + HCCLCHECK(HcclGetRootInfo(&root_info)); + MPI_CHECK(MPI_Bcast(&root_info, HCCL_ROOT_INFO_BYTES, MPI_CHAR, 0, MPI_COMM_WORLD)); + //MPI_Barrier(MPI_COMM_WORLD); + LOGir << "Count:" << device_count << "HCCL init in device" << hccl_device_id; + HCCLCHECK(HcclCommInitRootInfo(device_count, &root_info, hccl_device_id, &comm)); + ACLCHECK(aclrtCreateStream(&aclstream)); + LOGi << "HCCL init success in device" << hccl_device_id; + } + + ~hccl_initer() { + if (!device_count) return; + if (!inside_mpi) return; + if (!use_device_mpi) return; + HCCLCHECK(HcclCommDestroy(comm)); + } +}; + +static hccl_initer hccl_initer; +} \ No newline at end of file diff --git a/python/jittor/extern/mpi/ops/mpi_all_reduce_op.cc b/python/jittor/extern/mpi/ops/mpi_all_reduce_op.cc index 2ffc3bf8..96892b78 100644 --- a/python/jittor/extern/mpi/ops/mpi_all_reduce_op.cc +++ b/python/jittor/extern/mpi/ops/mpi_all_reduce_op.cc @@ -41,10 +41,18 @@ MpiAllReduceOp::MpiAllReduceOp(Var* x, NanoString op) : x(x), op(op) { static auto nccl_all_reduce = has_op("nccl_all_reduce") ? get_op_info("nccl_all_reduce").get_constructor() : nullptr; + static auto hccl_all_reduce = has_op("hccl_all_reduce") + ? get_op_info("hccl_all_reduce").get_constructor() + : nullptr; if (nccl_all_reduce) { auto var = nccl_all_reduce(x); forward(var); return; + } else if (hccl_all_reduce) { + auto var = hccl_all_reduce(x, "sum"); + //exe.run_sync({var}, true); + forward(var); + return; } } #endif diff --git a/python/jittor/extern/mpi/ops/mpi_broadcast_op.cc b/python/jittor/extern/mpi/ops/mpi_broadcast_op.cc index af8c1895..34efeb3c 100644 --- a/python/jittor/extern/mpi/ops/mpi_broadcast_op.cc +++ b/python/jittor/extern/mpi/ops/mpi_broadcast_op.cc @@ -26,10 +26,18 @@ MpiBroadcastOp::MpiBroadcastOp(Var* x, int root) : x(x), root(root) { static auto nccl_broadcast = has_op("nccl_broadcast") ? get_op_info("nccl_broadcast").get_constructor() : nullptr; + static auto hccl_broadcast = has_op("hccl_broadcast") + ? 
get_op_info("hccl_broadcast").get_constructor() + : nullptr; if (nccl_broadcast) { auto var = nccl_broadcast(x, root); forward(var); return; + } else if (hccl_broadcast) { + auto var = hccl_broadcast(x, root); + //exe.run_sync({var}, true); + forward(var); + return; } } #endif diff --git a/python/jittor/extern/mpi/ops/mpi_reduce_op.cc b/python/jittor/extern/mpi/ops/mpi_reduce_op.cc index 77d86a82..78294548 100644 --- a/python/jittor/extern/mpi/ops/mpi_reduce_op.cc +++ b/python/jittor/extern/mpi/ops/mpi_reduce_op.cc @@ -41,13 +41,18 @@ MpiReduceOp::MpiReduceOp(Var* x, NanoString op, int root) : x(x), op(op), root(r static auto nccl_reduce = has_op("nccl_reduce") ? get_op_info("nccl_reduce").get_constructor() : nullptr; + static auto hccl_reduce = has_op("hccl_reduce") + ? get_op_info("hccl_reduce").get_constructor() + : nullptr; if (nccl_reduce) { auto var = nccl_reduce(x, root); forward(var); return; + } else if (hccl_reduce) { + forward(var); + return; } } - #endif y = create_output(nullptr, x->dtype()); } diff --git a/python/jittor/extern/mpi/src/mpi_wrapper.cc b/python/jittor/extern/mpi/src/mpi_wrapper.cc index 498633a1..86b5abc1 100644 --- a/python/jittor/extern/mpi/src/mpi_wrapper.cc +++ b/python/jittor/extern/mpi/src/mpi_wrapper.cc @@ -8,6 +8,13 @@ // file 'LICENSE.txt', which is part of this source code package. // *************************************************************** #include +#include +#include +#include + +#if defined(__x86_64__) || defined(_M_X64) +#include +#endif #include #include #include @@ -31,8 +38,80 @@ namespace jittor { MPI_Datatype MPI_HALF; MPI_Op MPI_HALF_ADD; +#if !defined(__x86_64__) && !defined(_M_X64) +// ARM架构下的FP16-FP32转换辅助函数 +static inline float fp16_to_fp32_value(uint16_t h) { + unsigned sign = ((h >> 15) & 1); + unsigned exponent = ((h >> 10) & 0x1f); + unsigned mantissa = ((h & 0x3ff) << 13); + + if (exponent == 0) { + if (mantissa == 0) { + return sign ? -0.0f : 0.0f; + } else { + // 非规格化数 + while (!(mantissa & 0x400000)) { + mantissa <<= 1; + exponent -= 1; + } + exponent += 1; + mantissa &= ~0x400000; + } + } else if (exponent == 31) { + if (mantissa == 0) { + return sign ? 
-std::numeric_limits::infinity() : std::numeric_limits::infinity(); + } else { + return std::numeric_limits::quiet_NaN(); + } + } + + exponent += (127 - 15); + mantissa <<= 10; + + unsigned int i = ((sign << 31) | (exponent << 23) | mantissa); + float f; + std::memcpy(&f, &i, sizeof(float)); + return f; +} + +static inline uint16_t fp32_to_fp16_value(float f) { + unsigned int i; + std::memcpy(&i, &f, sizeof(float)); + + unsigned sign = ((i >> 31) & 0x1); + unsigned exponent = ((i >> 23) & 0xff); + unsigned mantissa = (i & 0x7fffff); + + unsigned short h = 0; + + if (exponent == 0) { + // zero or subnormal + h = (sign << 15); + } else if (exponent == 0xff) { + // infinity or NaN + h = (sign << 15) | 0x7c00; + if (mantissa) h |= 0x200; + } else { + // normalized number + int new_exp = exponent - 127 + 15; + if (new_exp < 0) { + // underflow to zero + h = (sign << 15); + } else if (new_exp > 30) { + // overflow to infinity + h = (sign << 15) | 0x7c00; + } else { + // normal conversion + h = (sign << 15) | (new_exp << 10) | (mantissa >> 13); + } + } + + return h; +} +#endif + void HalfAdd(void* invec, void* inoutvec, int* len, MPI_Datatype* type) { - // return; +#if defined(__x86_64__) || defined(_M_X64) short* in = (short*)invec; short* inout = (short*)inoutvec; @@ -62,9 +141,27 @@ void HalfAdd(void* invec, void* inoutvec, int* len, MPI_Datatype* type) { // convert the single-precision results back to half precision and store them *(inout + i) = _mm_cvtps_ph(out, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)[0]; } +#else + // ARM implementation: element-wise half-precision addition via FP32 + uint16_t* in = (uint16_t*)invec; + uint16_t* inout = (uint16_t*)inoutvec; + int total = *len; + + // simple element-wise addition + for (int i = 0; i < total; i++) { + // convert FP16 to FP32 + float in_val = fp16_to_fp32_value(in[i]); + float inout_val = fp16_to_fp32_value(inout[i]); + + // perform the addition + float result = in_val + inout_val; + + // convert the result back to FP16 + inout[i] = fp32_to_fp16_value(result); + } +#endif } - int mpi_world_size = 1; int mpi_world_rank = 0; int mpi_local_size = 1; diff --git a/python/jittor/misc.py b/python/jittor/misc.py index 46ea7cea..e35b1429 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -625,10 +625,6 @@ def unique( #include #include - #include - #include - #include - #include #include ''', @@ -713,12 +709,9 @@ def unique( #include #include - #include #include #include - #include - @alias(input_sorted, in0) @alias(diff, in1) @alias(indice, in2) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 69a07847..6310cfcb 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -3536,3 +3536,36 @@ def mish(x, inplace=False): def skip_init(module_cls, *args, **kw): return module_cls(*args, **kw) + +class FlashAttention(Module): + def __init__(self,headnum, + layout = "BSND", + prefix = None, + qstart = None, + kvstart = None, + scale = 1.0, + prob = 1.0, + pretokens = 2147483647, + nexttokens = 2147483647, + innerprecise = 0, + sparsemode = 0, + psetype = 1): + self.headnum = headnum + self.layout = layout + self.prefix = prefix + self.qstart = qstart + self.kvstart = kvstart + self.scale = scale + self.prob = prob + self.pretokens = pretokens + self.nexttokens = nexttokens + self.innerprecise = innerprecise + self.sparsemode = sparsemode + self.psetype = psetype + + def execute(self,q,k,v, + realshift = None, + dropMask = None, + paddingMask = None, + attenMask = None): + pass diff --git a/python/jittor/src/common.h b/python/jittor/src/common.h index 8ca51367..58c6a37a 100644 --- a/python/jittor/src/common.h +++ b/python/jittor/src/common.h @@ -8,6 +8,7 @@ #include #include #include "utils/log.h" +#include "../extern/acl/aclnn/aclnn.h" #define
JIT_TEST(name) extern void jit_test_ ## name () void expect_error(std::function func); diff --git a/python/jittor/src/jit_key.h b/python/jittor/src/jit_key.h index 7dd27292..b41ca887 100644 --- a/python/jittor/src/jit_key.h +++ b/python/jittor/src/jit_key.h @@ -79,6 +79,11 @@ struct JitKey { uint data; explicit dec2(uint data) : data(data) {} }; + + struct dec3 { + uint data; + explicit dec3(uint data) : data(data) {} + }; }; struct __jk_int128 { @@ -183,6 +188,16 @@ inline JK& operator<<(JK& jk, const JK::Oxhex2& h) { return jk << "0x" << JK::hex2(h.data); } +inline JK& operator<<(JK& jk, const JK::dec3& h) { + uint8 a = h.data % 10; + uint8 b = h.data / 10 % 10; + uint8 c = h.data / 100; + if (c) jk << (char)(c+'0'), jk << (char)(b+'0'); + else if (b) jk << (char)(b+'0'); + return jk << (char)(a+'0'); +} + + inline JK& operator<<(JK& jk, const JK::dec2& h) { uint8 a = h.data % 10; uint8 b = h.data / 10; @@ -195,6 +210,14 @@ inline JK& operator<<(JK& jk, const JK::dec1& h) { return jk << (char)(a+'0'); } +inline std::ostream& operator<<(std::ostream& os, const JK::dec3& h) { + uint8 a = h.data % 10; + uint8 b = h.data / 10 %10; + uint8 c = h.data / 100; + if (c) os << (char)(c+'0'), os << (char)(b+'0'); + else if (b) os << (char)(b+'0'); + return os << (char)(a+'0'); +} inline std::ostream& operator<<(std::ostream& os, const JK::dec2& h) { uint8 a = h.data % 10; uint8 b = h.data / 10; diff --git a/python/jittor/src/mem/allocator/cuda_dual_allocator.h b/python/jittor/src/mem/allocator/cuda_dual_allocator.h index b0be66c8..0cc94a31 100644 --- a/python/jittor/src/mem/allocator/cuda_dual_allocator.h +++ b/python/jittor/src/mem/allocator/cuda_dual_allocator.h @@ -30,7 +30,7 @@ EXTERN_LIB bool no_cuda_error_when_free; struct CudaDualAllocator : Allocator { //for recycle block_id - static const size_t ID_LIMIT = 1 << 20; + static const size_t ID_LIMIT = 1 << 16; int n_free_ids; int free_ids[ID_LIMIT]; DualAllocation allocations[ID_LIMIT]; diff --git a/python/jittor/src/mem/allocator/sfrl_allocator.cc b/python/jittor/src/mem/allocator/sfrl_allocator.cc index 22bf89b8..add3aef9 100644 --- a/python/jittor/src/mem/allocator/sfrl_allocator.cc +++ b/python/jittor/src/mem/allocator/sfrl_allocator.cc @@ -15,8 +15,8 @@ namespace jittor { DEFINE_FLAG(int, use_sfrl_allocator, 1, "Enable sfrl allocator"); -DEFINE_FLAG(int64, sfrl_large_block_size_device, 5242880, "sfrl_large_block_size, larger will reduce memory shard, only affect device"); -constexpr int64 sfrl_large_block_size_cpu=5242880; +DEFINE_FLAG(int64, sfrl_large_block_size_device, 20971520, "sfrl_large_block_size, larger will reduce memory shard, only affect device"); +constexpr int64 sfrl_large_block_size_cpu=20971520; std::vector CachingBlockPool::block_ids; //start from 1 diff --git a/python/jittor/src/mem/allocator/sfrl_allocator.h b/python/jittor/src/mem/allocator/sfrl_allocator.h index 3d409e41..2bd9a170 100644 --- a/python/jittor/src/mem/allocator/sfrl_allocator.h +++ b/python/jittor/src/mem/allocator/sfrl_allocator.h @@ -36,7 +36,7 @@ struct CachingBlockPool { //start from 1 static size_t tot_block_id; static std::unique_ptr occupied_id_mapper; - static const size_t ID_LIMIT = 1 << 21; + static const size_t ID_LIMIT = 1 << 18; pair get_key(CachingBlock* block); diff --git a/python/jittor/src/mem/allocator/temp_allocator.h b/python/jittor/src/mem/allocator/temp_allocator.h index ce2cd3b5..08dc0994 100644 --- a/python/jittor/src/mem/allocator/temp_allocator.h +++ b/python/jittor/src/mem/allocator/temp_allocator.h @@ -23,7 +23,7 @@ struct 
TempCachingBlock { struct TempAllocator : Allocator { static const size_t ALIGN_SIZE = 512; - static const size_t ID_LIMIT = 1 << 21; + static const size_t ID_LIMIT = 1 << 18; static vector temp_allocators; Allocator* underlying; size_t cache_blocks_limit, used_memory, unused_memory; diff --git a/python/jittor/src/misc/cuda_flags.cc b/python/jittor/src/misc/cuda_flags.cc index 999385fc..857eec6e 100644 --- a/python/jittor/src/misc/cuda_flags.cc +++ b/python/jittor/src/misc/cuda_flags.cc @@ -20,6 +20,8 @@ DEFINE_FLAG_WITH_SETTER(int, use_cuda, 0, "Use cuda or not. 1 for trying to use cuda, 2 for forcing to use cuda."); DEFINE_FLAG_WITH_SETTER(int, device_id, -1, "number of the device to used"); +DEFINE_FLAG_WITH_SETTER(int, sync_run, 1, + "Enable per-op-sync or not"); EXTERN_LIB void sync_all(bool device_sync); @@ -32,6 +34,10 @@ int get_device_count() { } #endif +void setter_sync_run(int value) { + if(sync_run == value) return; + sync_run = value; +} void setter_use_cuda(int value) { if (use_cuda == value) return; diff --git a/python/jittor/src/misc/cuda_flags.h b/python/jittor/src/misc/cuda_flags.h index 4e06897d..6c937690 100644 --- a/python/jittor/src/misc/cuda_flags.h +++ b/python/jittor/src/misc/cuda_flags.h @@ -14,6 +14,7 @@ namespace jittor { DECLARE_FLAG(int, use_cuda); +DECLARE_FLAG(int, sync_run); // @pyjt(get_device_count) int get_device_count(); diff --git a/python/jittor/src/misc/miniz.h b/python/jittor/src/misc/miniz.h index 1d328456..d3445075 100755 --- a/python/jittor/src/misc/miniz.h +++ b/python/jittor/src/misc/miniz.h @@ -1341,8 +1341,8 @@ struct ZipFile { zip_archive = nullptr; } } - // if (!zip_archive) - // throw std::runtime_error("Failed to open zip file: " + filename); + if (!zip_archive) + throw std::runtime_error("Failed to open zip file: " + filename); } // @pyjt(__dealloc__) inline ~ZipFile() { diff --git a/python/jittor/src/ops/array_op.cc b/python/jittor/src/ops/array_op.cc index 0841a288..23b0f38b 100644 --- a/python/jittor/src/ops/array_op.cc +++ b/python/jittor/src/ops/array_op.cc @@ -27,17 +27,19 @@ namespace array_local { cudaStream_t stream; cudaEvent_t event; + struct Init { Init() { if (!get_device_count()) return; - checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); - checkCudaErrors(cudaEventCreate(&event, cudaEventDisableTiming)); + //checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + //checkCudaErrors(cudaEventCreate(&event, cudaEventDisableTiming)); + stream = aclstream; } ~Init() { if (!get_device_count()) return; - peekCudaErrors(cudaDeviceSynchronize()); - peekCudaErrors(cudaStreamDestroy(stream)); - peekCudaErrors(cudaEventDestroy(event)); + //peekCudaErrors(cudaDeviceSynchronize()); + //peekCudaErrors(cudaStreamDestroy(stream)); + //peekCudaErrors(cudaEventDestroy(event)); } } init; @@ -102,8 +104,11 @@ void ArrayOp::run() { auto host_ptr = cuda_dual_allocator.get_dual_allocation(allocation.allocation).host_ptr; checkCudaErrors(cudaMemcpyAsync( allocation.ptr, host_ptr, allocation.size, cudaMemcpyHostToDevice, stream)); - checkCudaErrors(cudaEventRecord(event, stream)); - checkCudaErrors(cudaStreamWaitEvent(0, event, 0)); + // checkCudaErrors(aclrtMemcpyAsync( + // allocation.ptr, allocation.size, host_ptr, allocation.size, cudaMemcpyHostToDevice, aclstream)); + // checkCudaErrors(cudaEventRecord(event, stream)); + // checkCudaErrors(cudaStreamWaitEvent(0, event, 0)); + // checkCudaErrors(aclrtSynchronizeStream(aclstream)); // delay free this allocation allocation.allocator = &delay_free; } 
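The JK::dec3 formatter added in jit_key.h above (and adopted by code_op.cc just below) extends JK::dec2, which only prints a tens and a ones digit: for values of 100 or more the tens digit overflows into a non-digit character, so operand indices in jit keys such as "in123_type" would be corrupted, presumably the situation CodeOp::grad can reach once dout and every pout are appended to the input list. A small standalone Python sketch of the two digit schemes, mirroring the C++ above for illustration only:

def dec2(n):
    # mirrors JK::dec2: at most a tens digit and a ones digit, no leading zero
    a, b = n % 10, n // 10
    return (chr(b + ord('0')) if b else '') + chr(a + ord('0'))

def dec3(n):
    # mirrors the new JK::dec3: adds a hundreds digit, still no leading zeros
    a, b, c = n % 10, n // 10 % 10, n // 100
    s = chr(c + ord('0')) + chr(b + ord('0')) if c else (chr(b + ord('0')) if b else '')
    return s + chr(a + ord('0'))

assert dec3(7) == '7' and dec3(42) == '42' and dec3(128) == '128'
assert dec2(42) == '42'
assert dec2(128) == '<8'   # 12 + '0' is '<', not a digit, so keys with 100+ operands break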
diff --git a/python/jittor/src/ops/code_op.cc b/python/jittor/src/ops/code_op.cc index 021b615e..144598ed 100644 --- a/python/jittor/src/ops/code_op.cc +++ b/python/jittor/src/ops/code_op.cc @@ -106,13 +106,13 @@ VarPtr CodeOp::grad(Var* out, Var* dout, Var* v, int v_index) { // TODO: remove unused deps // dout -> dout std::stringstream new_alias; - new_alias << "\n@alias(dout,in" << JK::dec2(inputs.size()) << ")\n"; + new_alias << "\n@alias(dout,in" << JK::dec3(inputs.size()) << ")\n"; inputs.push_back(dout); // _outputs[i] -> poutj for (int i=0; i<_outputs.size(); i++) { - new_alias << "\n@alias(pout" << JK::dec2(i) << ",in" << JK::dec2(inputs.size()) << ")\n"; + new_alias << "\n@alias(pout" << JK::dec3(i) << ",in" << JK::dec3(inputs.size()) << ")\n"; if (_outputs[i] == out) - new_alias << "\n@alias(pout,in" << JK::dec2(inputs.size()) << ")\n"; + new_alias << "\n@alias(pout,in" << JK::dec3(inputs.size()) << ")\n"; inputs.push_back(_outputs[i]); } auto alias = new_alias.str(); @@ -130,18 +130,19 @@ void CodeOp::jit_prepare(JK& jk) { // forward: in0 in1 in2 -> out0 out1 // backward: in0 in1 in2 in3(pout0) in4(pout1) - jk << "«IN_SIZE:" << JK::dec2(_inputs.size()); + jk << "«IN_SIZE:" << JK::dec3(_inputs.size()); for (uint i=0; i<_inputs.size(); i++) { - jk << "«in" << JK::dec2(i) << "_dim:" + //LOGir<shape.size()); - jk << "«in" << JK::dec2(i) << "_type:" + jk << "«in" << JK::dec3(i) << "_type:" << _inputs[i]->dtype(); } - jk << "«OUT_SIZE:" << JK::dec2(_outputs.size()); + jk << "«OUT_SIZE:" << JK::dec3(_outputs.size()); for (uint i=0; i<_outputs.size(); i++) { - jk << "«out" << JK::dec2(i) << "_dim:" + jk << "«out" << JK::dec3(i) << "_dim:" << JK::hex1(_outputs[i]->shape.size()); - jk << "«out" << JK::dec2(i) << "_type:" + jk << "«out" << JK::dec3(i) << "_type:" << _outputs[i]->dtype(); } string& header = flags.get(NodeFlags::_cuda) ? 
diff --git a/python/jittor/src/ops/copy_op.cc b/python/jittor/src/ops/copy_op.cc index ebc748cf..b48e57f5 100644 --- a/python/jittor/src/ops/copy_op.cc +++ b/python/jittor/src/ops/copy_op.cc @@ -17,6 +17,8 @@ namespace jittor { +EXTERN_LIB aclrtStream aclstream; + CopyOp::CopyOp(Var* x) { flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); @@ -42,6 +44,8 @@ void CopyOp::run() { #ifdef HAS_CUDA if (flags.get(NodeFlags::_cuda)) { checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDeviceToDevice, 0)); + // checkCudaErrors(aclrtMemcpyAsync(y_ptr, size, x_ptr, size, cudaMemcpyDeviceToDevice, aclstream)); + // checkCudaErrors(aclrtSynchronizeStream(aclstream)); } else #endif { diff --git a/python/jittor/src/ops/fetch_op.cc b/python/jittor/src/ops/fetch_op.cc index 634f4f07..98de48d3 100644 --- a/python/jittor/src/ops/fetch_op.cc +++ b/python/jittor/src/ops/fetch_op.cc @@ -47,6 +47,7 @@ Init() { if (!get_device_count()) return; checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); checkCudaErrors(cudaEventCreate(&event, cudaEventDisableTiming)); + stream = aclstream; } ~Init() { if (!get_device_count()) return; @@ -122,8 +123,11 @@ void FetchOp::run() { new (&allocation) Allocation(&cuda_dual_allocator, v->size); // mostly device to device #if IS_CUDA + // checkCudaErrors(cudaMemcpyAsync( + // allocation.ptr, v->mem_ptr, v->size, cudaMemcpyDefault, stream)); checkCudaErrors(cudaMemcpyAsync( - allocation.ptr, v->mem_ptr, v->size, cudaMemcpyDefault, stream)); + allocation.ptr, v->size, v->mem_ptr, v->size, cudaMemcpyDefault, aclstream)); + checkCudaErrors(aclrtSynchronizeStream(aclstream)); #else checkCudaErrors(cudaMemcpyAsync( allocation.ptr, v->mem_ptr, v->size, cudaMemcpyDeviceToDevice, stream)); @@ -131,8 +135,11 @@ void FetchOp::run() { auto host_ptr = cuda_dual_allocator.get_dual_allocation( allocation.allocation).host_ptr; // device to host - checkCudaErrors(cudaMemcpyAsync( - host_ptr, allocation.ptr, v->size, cudaMemcpyDeviceToHost, stream)); + // checkCudaErrors(cudaMemcpyAsync( + // host_ptr, allocation.ptr, v->size, cudaMemcpyDeviceToHost, stream)); + checkCudaErrors(aclrtMemcpyAsync( + host_ptr, v->size, allocation.ptr, v->size, cudaMemcpyDeviceToHost, aclstream)); + checkCudaErrors(aclrtSynchronizeStream(aclstream)); allocation.ptr = host_ptr; has_cuda_memcpy = true; } else diff --git a/python/jittor/src/ops/setitem_op.cc b/python/jittor/src/ops/setitem_op.cc index b8d8d8e2..2ba818ac 100644 --- a/python/jittor/src/ops/setitem_op.cc +++ b/python/jittor/src/ops/setitem_op.cc @@ -330,6 +330,8 @@ void SetitemOp::jit_run() { #else if (op != ip) checkCudaErrors(cudaMemcpyAsync(op, ip, out->size, cudaMemcpyDeviceToDevice, 0)); + // checkCudaErrors(aclrtMemcpyAsync(op, out->size, ip, out->size, cudaMemcpyDeviceToDevice, aclstream)); + // checkCudaErrors(aclrtSynchronizeStream(aclstream)); #endif if (ns.get(GetitemOp::_inplace) && diff --git a/python/jittor/src/pyjt/py_converter.h b/python/jittor/src/pyjt/py_converter.h index 5101b411..f8ac109a 100644 --- a/python/jittor/src/pyjt/py_converter.h +++ b/python/jittor/src/pyjt/py_converter.h @@ -892,7 +892,7 @@ void load_var_slice(PyObject* obj, T* var_slice, vector>& auto* vh = from_py_object(obj, holders.back()); auto vv = (decltype(var_slice->var)*)vh; CHECK(vv[0]->dtype() != ns_bool) << "Please convert bool slice into jt.array, example:\n" - "a[[True,False,False]] ---> a[jt.array([True,False,False])"; + "a[[True,False,False]] ---> a[jt.array([True,False,False])]"; var_slice->set_var(vv[0]); } } diff --git 
a/python/jittor/test/test_acl.py b/python/jittor/test/test_acl.py index 4b72fbae..61d271a2 100644 --- a/python/jittor/test/test_acl.py +++ b/python/jittor/test/test_acl.py @@ -1,6 +1,6 @@ # *************************************************************** -# Copyright (c) 2023 Jittor. All Rights Reserved. -# Maintainers: Dun Liang . +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** @@ -11,27 +11,28 @@ import numpy as np from jittor import init, Module import numpy as np - @unittest.skipIf(not jt.compiler.has_acl, "No ACL found") class TestACL(unittest.TestCase): @jt.flag_scope(use_acl=1) def test_array(self): - print("use_acl", jt.flags.use_acl) - a = jt.array([1,2,3]) - np.testing.assert_allclose(a.numpy(), [1,2,3]) + a = jt.array([1, 2, 3]) + np.testing.assert_allclose(a.numpy(), [1, 2, 3]) + print('test_array pass') @jt.flag_scope(use_acl=1) def test_add(self): - a = jt.array([1,2,3]) - b = a+a - np.testing.assert_allclose(b.numpy(), [2,4,6]) + a = jt.array([1, 2, 3]) + b = a + a + np.testing.assert_allclose(b.numpy(), [2, 4, 6]) + print('test_add pass') @jt.flag_scope(use_acl=1) def test_add_float(self): - a = jt.array([1.0,2.0,3.0]) - b = a+a - np.testing.assert_allclose(b.numpy(), [2,4,6]) + a = jt.array([1.0, 2.0, 3.0]) + b = a + a + np.testing.assert_allclose(b.numpy(), [2, 4, 6]) + print('test_add_float pass') @jt.flag_scope(use_acl=1) def test_array_cast(self): @@ -39,6 +40,7 @@ class TestACL(unittest.TestCase): x = np.random.rand(10) y = jt.float32(x) np.testing.assert_allclose(x, y.numpy()) + print('test_array_cast pass') @jt.flag_scope(use_acl=1) def test_array_cast_half(self): @@ -46,63 +48,67 @@ class TestACL(unittest.TestCase): x = np.random.rand(10).astype("float32") y = jt.float16(x) np.testing.assert_allclose(x.astype("float16"), y.numpy()) + print('test_array_cast_half pass') @jt.flag_scope(use_acl=1) def test_rand(self): a = jt.rand(10) - b = a*10 + b = a * 10 b.sync() print(b) def test_meminfo(self): jt.display_memory_info() + print('test_meminfo pass') @jt.flag_scope(use_acl=1) def test_conv(self): x = jt.rand(10, 3, 50, 50) - w = jt.rand(4,3,3,3) + w = jt.rand(4, 3, 3, 3) # x = jt.rand(2, 2, 1, 1) # w = jt.rand(2,2,1,1) y = jt.nn.conv2d(x, w) y.sync(True) y1 = y.data mask = jt.rand_like(y) - dx, dw = jt.grad((y*mask).sum(), [x, w]) + dx, dw = jt.grad((y * mask).sum(), [x, w]) dx1, dw1 = dx.data, dw.data # dw, = jt.grad((y*mask).sum(), [w]) # dw1 = dw.data with jt.flag_scope(use_acl=0): y = jt.nn.conv2d(x, w) y2 = y.data - dx, dw = jt.grad((y*mask).sum(), [x, w]) + dx, dw = jt.grad((y * mask).sum(), [x, w]) dx2, dw2 = dx.data, dw.data # dw, = jt.grad((y*mask).sum(), [w]) # dw2 = dw.data np.testing.assert_allclose(y1, y2) np.testing.assert_allclose(dx1, dx2) np.testing.assert_allclose(dw1, dw2) + print('test_conv pass') @jt.flag_scope(use_acl=1) def test_matmul(self): # x = jt.rand(10, 3, 50, 50) # w = jt.rand(4,3,3,3) - x = jt.rand(10,10) - w = jt.rand(10,10) + x = jt.rand(10, 10) + w = jt.rand(10, 10) y = jt.matmul(x, w) ny = np.matmul(x.numpy(), w.numpy()) np.testing.assert_allclose(y.numpy(), ny, atol=1e-3, rtol=1e-3) - # y.sync(True) + print('test_matmul pass') @jt.flag_scope(use_acl=1) def test_max(self): - x = jt.rand(3,3) + x = jt.rand(3, 3) y = x.max(1).data ny = x.data.max(1) np.testing.assert_allclose(y, ny) + print('test_max pass') 
@jt.flag_scope(use_acl=1) def test_sum(self): - x = jt.rand(3,3).float16() + x = jt.rand(3, 3).float16() print(x) # return y = x.sum(1).data @@ -110,65 +116,71 @@ class TestACL(unittest.TestCase): print(x) ny = x.data.sum(1) np.testing.assert_allclose(y, ny) + print('test_sum pass') @jt.flag_scope(use_acl=1) def test_broadcast(self): x = jt.rand(3) # print(x) - y = x.broadcast([3,3]).data + y = x.broadcast([3, 3]).data ny = np.broadcast_arrays(x.data, y)[0] np.testing.assert_allclose(y, ny) print(x, y) # y = x.broadcast([3,3], dims=[1]).data - y = jt.broadcast(x, shape=(3,3), dims=[1]).data + y = jt.broadcast(x, shape=(3, 3), dims=[1]).data with jt.flag_scope(use_acl=0): - ny = jt.broadcast(x, shape=(3,3), dims=[1]).data + ny = jt.broadcast(x, shape=(3, 3), dims=[1]).data # ny = np.broadcast_arrays(x.data, y)[0] np.testing.assert_allclose(y, ny) print(x, y) + print('test_broadcast pass') @jt.flag_scope(use_acl=1) def test_resnet(self): from jittor.models import resnet50 net = resnet50() - x = jt.rand(2,3,224,224) + x = jt.rand(2, 3, 224, 224) y = net(x) y.sync() - -def matmul(a, b): - (n, m), k = a.shape, b.shape[-1] - a = a.broadcast([n,m,k], dims=[2]) - b = b.broadcast([n,m,k], dims=[0]) - return (a*b).sum(dim=1) - class Linear(Module): + def __init__(self, in_features, out_features, bias=True): - self.w = (jt.random((in_features, out_features))-0.5) / in_features**0.5 - self.b = jt.random((out_features,))-0.5 if bias else None + self.w = (jt.random( + (in_features, out_features), type='normal') - 0.5) / in_features**0.5 + self.b = jt.random((out_features, ), type='normal') - 0.5 if bias else None + def execute(self, x): - x = matmul(x, self.w) - if self.b is not None: - return x+self.b + x = jt.nn.matmul(x, self.w) + if self.b is not None: + return x + self.b return x + def relu(x): return jt.maximum(x, 0.0) + + Relu = jt.make_module(relu) + class Model(Module): + def __init__(self, input_size): self.linear1 = Linear(input_size, 10) self.relu1 = Relu() self.linear2 = Linear(10, 1) + def execute(self, x): x = self.linear1(x) x = self.relu1(x) return self.linear2(x) + @unittest.skipIf(not jt.compiler.has_acl, "No ACL found") class TestExample(unittest.TestCase): + @jt.flag_scope(use_acl=1) def test1(self): np.random.seed(0) @@ -180,26 +192,28 @@ class TestExample(unittest.TestCase): def get_data(n): for i in range(n): x = np.random.rand(batch_size, 1).astype("float32") - y = x*x + y = x * x yield jt.float32(x), jt.float32(y) - + model = Model(input_size=1) ps = model.parameters() - for i,(x,y) in enumerate(get_data(n)): + for i, (x, y) in enumerate(get_data(n)): jt.sync_all(True) pred_y = model(x).name("pred_y") loss = ((pred_y - y).sqr()).name("loss") loss_mean = loss.mean() - + gs = jt.grad(loss_mean, ps) for p, g in zip(ps, gs): p -= g * lr - - if i>2: - assert prev == jt.liveness_info(), f"memory leak {prev} {jt.liveness_info()}" + if i > 2: + assert prev == jt.liveness_info( + ), f"memory leak {prev} {jt.liveness_info()}" prev = jt.liveness_info() - print(f"step {i}, loss = {loss_mean.data.sum()} {jt.liveness_info()}") + print( + f"step {i}, loss = {loss_mean.data.sum()} {jt.liveness_info()}" + ) possible_results = [ 0.0009948202641680837, @@ -211,5 +225,6 @@ class TestExample(unittest.TestCase): jt.clean() + if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_aclop.py b/python/jittor/test/test_aclop.py index 41a6d748..9b5cb9ae 100644 --- a/python/jittor/test/test_aclop.py +++ 
b/python/jittor/test/test_aclop.py @@ -1,54 +1,167 @@ import unittest import jittor as jt -from .test_core import expect_error -import numpy as np -from jittor import init, Module import numpy as np +import time @unittest.skipIf(not jt.compiler.has_acl, "No ACL found") class TestACL(unittest.TestCase): + def setUp(self): + self.repeat_num = 10 + + def measure_time(self, func): + # warm up + for _ in range(5): + result = func() + if isinstance(result, list) or isinstance(result, tuple): + for i in result: + i.sync() + else: + result.sync() + + start_time = time.perf_counter() + for _ in range(self.repeat_num): + result = func() + if isinstance(result, list) or isinstance(result, tuple): + for i in result: + i.sync() + else: + result.sync() + jt.sync_all(True) + end_time = time.perf_counter() + elapsed = (end_time - start_time) / self.repeat_num + print(f"{self.id()} executed in {1000*elapsed:.6f} ms") + return result + @jt.flag_scope(use_acl=1) - def test_getitem(self): + def test_getitem_1(self): a = jt.ones(100, 2) - b = a[0:2, 0:2] + b = self.measure_time(lambda: a[0:2, 0:2]) np.testing.assert_allclose(b.numpy(), [[1, 1], [1, 1]]) - print("test getitem success") + print("test getitem (test case 1) success") @jt.flag_scope(use_acl=1) - def test_getitem_neg(self): - a = jt.ones(2, 3, 2) - b = a[0:1,0:-2] - np.testing.assert_allclose(b.numpy(), [[[1,1]]]) - print("test getitem neg success") + def test_getitem_2(self): + a = jt.ones((2, 3)) + b = self.measure_time(lambda: a[:, None, :]) + assert b.shape == [2, 1, 3] + print("test getitem (test case 2) success") @jt.flag_scope(use_acl=1) - def test_setitem(self): + def test_getitem_3(self): + a = jt.ones((2, 3, 4, 5, 10)) + b = self.measure_time(lambda: a[1, ...]) + assert b.shape == [3, 4, 5, 10] + print("test getitem (test case 3) success") + + @jt.flag_scope(use_acl=1) + def test_getitem_4(self): + a = jt.ones((2, 3, 4, 5, 10)) + b = self.measure_time(lambda: a[..., :2]) + assert b.shape == [2, 3, 4, 5, 2] + print("test getitem (test case 4) success") + + @jt.flag_scope(use_acl=1) + def test_getitem_5(self): + a = jt.ones((2, 3, 4, 5, 10)) + b = self.measure_time(lambda: a[1, ..., :2]) + assert b.shape == [3, 4, 5, 2] + print("test getitem (test case 5) success") + + @jt.flag_scope(use_acl=1) + def test_getitem_6(self): + a = jt.ones((2, 3, 4, 5, 10)) + b = self.measure_time(lambda: a[1, None, :, :, :, :2]) + assert b.shape == [1, 3, 4, 5, 2] + print("test getitem (test case 6) success") + + @jt.flag_scope(use_acl=1) + def test_getitem_7(self): + a = jt.ones((2, 3, 4, 5, 10)) + b = self.measure_time(lambda: a[1, 2, None, :, :, :2]) + assert b.shape == [1, 4, 5, 2] + print("test getitem (test case 7) success") + + @jt.flag_scope(use_acl=1) + def test_getitem_8(self): + a = jt.ones((2, 3, 4, 5, 10)) + b = self.measure_time(lambda: a[1, 2, None, :, :, None, :2]) + assert b.shape == [1, 4, 5, 1, 2] + print("test getitem (test case 8) success") + + @jt.flag_scope(use_acl=1) + def test_getitem_9(self): + a = jt.ones((2, 3, 4, 5, 10)) + b = self.measure_time(lambda: a[None, ..., None]) + assert b.shape == [1, 2, 3, 4, 5, 10, 1] + print("test getitem (test case 9) success") + + @jt.flag_scope(use_acl=1) + def test_getitem_10(self): + a = jt.ones((2, 3, 4, 5, 10)) + b = self.measure_time(lambda: a[None, ..., None, 1]) + assert b.shape == [1, 2, 3, 4, 5, 1] + print("test getitem (test case 10) success") + + @jt.flag_scope(use_acl=1) + def test_getitem_11(self): + a = jt.ones(10) + b = self.measure_time(lambda: a[2:]) + assert b.shape == [8] + 
print("test getitem (test case 11) success") + + @jt.flag_scope(use_acl=1) + def test_getitem_12(self): + a = jt.array([[1,2,3], [4,5,6], [7,8,9]]) + b = self.measure_time(lambda: a[[0,1,1]]) + np.testing.assert_allclose(b.numpy(), [[1, 2, 3], [4, 5, 6], [4, 5, 6]]) + print("test getitem (test case 12) success") + + @jt.flag_scope(use_acl=1) + def test_getitem_13(self): + a = jt.array([[1,2,3], [4,5,6], [7,8,9]]) + index = jt.array([0,1,1]) + b = self.measure_time(lambda: a[index]) + np.testing.assert_allclose(b.numpy(), [[1, 2, 3], [4, 5, 6], [4, 5, 6]]) + print("test getitem (test case 13) success") + + @jt.flag_scope(use_acl=1) + def test_getitem_14(self): + a = jt.array([[1, 2], [3, 4]]) + index = jt.array([[False,True],[True, False]]) + b = self.measure_time(lambda: a[index]) + np.testing.assert_allclose(b.numpy(), [2, 3]) + print("test getitem (test case 14) success") + + @jt.flag_scope(use_acl=1) + def test_setitem_1(self): a = jt.ones(2, 2) - b = jt.Var(0) - a[0:1, 0:1] = b + a[0:1, 0:1] = 0 np.testing.assert_allclose(a.numpy(), [[0, 1], [1, 1]]) - print("test setitem success") - + print("test setitem (test case 1) success") + + # @jt.flag_scope(use_acl=1) + # def test_setitem_2(self): + # a = jt.ones(2, 2) + # b = jt.Var(0) + # a[0:1, 0:1] = b + # np.testing.assert_allclose(a.numpy(), [[0, 1], [1, 1]]) + # print("test setitem (test case 2) success") + @jt.flag_scope(use_acl=1) - def test_setitem_neg(self): - a = jt.ones(2, 3, 2) - b = jt.Var(0) - a[0:1, 0:-2] = b - np.testing.assert_allclose(a.numpy(), [[[0,0],[1,1],[1,1]],[[1,1],[1,1],[1,1]]]) - print("test setitem neg success") + def test_setitem_3(self): + a = jt.array([[1, 2], [3, 4]]) + index = jt.array([[False,True],[True, False]]) + a[index] = 5 + np.testing.assert_allclose(a.numpy(), [[1, 5], [5, 4]]) + print("test setitem (test case 3) success") @jt.flag_scope(use_acl=1) def test_getitem_grad(self): a = jt.ones(2, 2) b = a[0:1, 0:1] - optimizer = jt.optim.SGD([a], 0.1) - loss = b.sum() - optimizer.zero_grad() - optimizer.backward(loss) - optimizer.step() - res = a.opt_grad(optimizer) + res = self.measure_time(lambda: jt.grad(b.sum(), a)) np.testing.assert_allclose(res.numpy(), [[1, 0], [0, 0]]) print("test getitem grad success") @@ -57,13 +170,8 @@ class TestACL(unittest.TestCase): a = jt.ones(3, 3) b = jt.ones(2, 2) a[0:2, 0:2] = b * 2 - optimizer = jt.optim.SGD([a, b], 0.1) - loss = a.sum() - optimizer.zero_grad() - optimizer.backward(loss) - optimizer.step() - res_a = a.opt_grad(optimizer) - res_b = b.opt_grad(optimizer) + res_a = self.measure_time(lambda: jt.grad(a.sum(), a)) + res_b = self.measure_time(lambda: jt.grad(a.sum(), b)) np.testing.assert_allclose(res_a.numpy(), [[1, 1, 1], [1, 1, 1], [1, 1, 1]]) np.testing.assert_allclose(res_b.numpy(), [[2, 2], [2, 2]]) @@ -73,232 +181,306 @@ class TestACL(unittest.TestCase): def test_concat(self): a = jt.ones(2, 2) b = jt.ones(2, 2) - c = jt.concat([a, b], 0) + c = self.measure_time(lambda: jt.concat([a, b], 0)) np.testing.assert_allclose(c.numpy(), [[1, 1], [1, 1], [1, 1], [1, 1]]) print("test concat success") - @jt.flag_scope(use_acl=1) - def test_concat_neg(self): - a = jt.ones(2, 2) - b = jt.ones(2, 2) - c = jt.concat([a, b], -1) - np.testing.assert_allclose(c.numpy(), [[1,1,1,1],[1,1,1,1]]) - print("test concat neg success") - - @jt.flag_scope(use_acl=1) - def test_concat_zero_dim(self): - a = jt.ones([]) - b = jt.zeros([]) - c = jt.concat([a, b], 0) - np.testing.assert_allclose(c.numpy(), [1,0]) - print("test concat zero dim success") - @jt.flag_scope(use_acl=1) 
def test_maxpool_grad(self): - a = jt.float32([[[[1,2,3,4],[2,3,4,1],[3,4,1,2],[4,1,2,3]]]]) + a = jt.ones(1, 1, 4, 4) max_pool = jt.nn.Pool(2, op='maximum') - optimizer = jt.optim.SGD([a], 0.1) - b = max_pool(a) - loss = b.sum() - optimizer.zero_grad() - optimizer.backward(loss) - optimizer.step() - res = a.opt_grad(optimizer) + + def tmp_func(a): + b = max_pool(a) + res = jt.grad(b.sum(), a) + return res + + res = self.measure_time(lambda: tmp_func(a)) np.testing.assert_allclose( res.numpy(), - [[[[0, 0, 0, 1], [0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]]]) + [[[[1, 0, 1, 0], [0, 0, 0, 0], [1, 0, 1, 0], [0, 0, 0, 0]]]]) print("test maxpool grad success") @jt.flag_scope(use_acl=1) def test_triu(self): a = jt.ones(3, 3) - b = jt.triu_(a, 0) - c = jt.triu_(a, 1) - d = jt.triu_(a, -1) + b = self.measure_time(lambda: jt.triu_(a, 0)) + c = self.measure_time(lambda: jt.triu_(a, 1)) np.testing.assert_allclose(b.numpy(), [[1, 1, 1], [0, 1, 1], [0, 0, 1]]) np.testing.assert_allclose(c.numpy(), [[0, 1, 1], [0, 0, 1], [0, 0, 0]]) - np.testing.assert_allclose(d.numpy(), - [[1, 1, 1], [1, 1, 1], [0, 1, 1]]) print("test triu success") - @jt.flag_scope(use_acl=1) - def test_bmm(self): - a = jt.float32([[[1,2],[3,4]],[[2,1],[4,3]],[[1,2],[4,3]]]) - b = jt.bmm(a, a) - np.testing.assert_allclose( - b.numpy(), [[[7, 10], [15, 22]], [[8, 5], [20, 13]], [[9, 8], [16, 17]]]) - print("test bmm success") - - @jt.flag_scope(use_acl=1) - def test_matmul(self): - a = jt.float32([[[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]]]) - b = jt.float32([[1,1],[1,1],[1,1],[1,1]]) - c = jt.matmul(a, b) - np.testing.assert_allclose(c.numpy(), - [[[10, 10], [26, 26], [42, 42], [58, 58]]]) - print("test matmul success") - @jt.flag_scope(use_acl=1) def test_maxpool(self): - a = jt.float32([[[[1,2,3,4],[2,3,4,1],[3,4,1,2],[4,1,2,3]]]]) + a = jt.ones(1, 1, 4, 4) max_pool = jt.nn.Pool(2, op='maximum') - np.testing.assert_allclose(max_pool(a).numpy(), [[[[3, 4], [4, 3]]]]) + b = self.measure_time(lambda: max_pool(a)) + np.testing.assert_allclose(b.numpy(), [[[[1, 1], [1, 1]]]]) print("test maxpool success") @jt.flag_scope(use_acl=1) def test_transpose(self): - a = jt.float32([[[1,2],[3,4]]]) - b = a.transpose(0, 2) - np.testing.assert_allclose(b.numpy(), [[[1], [3]], [[2], [4]]]) + a = jt.ones(1, 2, 2) + b = self.measure_time(lambda: a.transpose(0, 2)) + np.testing.assert_allclose(b.numpy(), [[[1], [1]], [[1], [1]]]) print("test transpose success") @jt.flag_scope(use_acl=1) - def test_transpose_neg(self): - a = jt.float32([[[1,2],[3,4]]]) - b = a.transpose(1, -1) - np.testing.assert_allclose(b.numpy(), [[[1,3], [2,4]]]) - print("test transpose neg success") - - @jt.flag_scope(use_acl=1) - def test_matmul_grad(self): - a = jt.float32([[[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]]]) - b = jt.float32([[1,1],[1,1],[1,1],[1,1]]) - optimizer = jt.optim.SGD([a, b], 0.1) - loss = jt.matmul(a, b).sum() - optimizer.zero_grad() - optimizer.backward(loss) - optimizer.step() - res_a = a.opt_grad(optimizer) - res_b = b.opt_grad(optimizer) - np.testing.assert_allclose(res_a.numpy(), [[[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]]) - np.testing.assert_allclose(res_b.numpy(), [[28, 28], [32, 32], [36, 36], [40, 40]]) - print("test matmul grad success") - - @jt.flag_scope(use_acl=1) - def test_bmm_grad(self): - a = jt.float32([[[1,2],[3,4]],[[2,1],[4,3]],[[1,2],[4,3]]]) - optimizer = jt.optim.SGD([a], 0.1) - c = jt.bmm(a, a) - loss = c.sum() - - optimizer.zero_grad() - optimizer.backward(loss) - optimizer.step() - - res = 
a.opt_grad(optimizer) + def test_matmul_1(self): + a = jt.arange(16).reshape(1, 4, 4).float() + b = jt.arange(8).reshape(4, 2).float() + f = self.measure_time(lambda: jt.matmul(a, b)) np.testing.assert_allclose( - res.numpy(), - [[[7, 11], [9, 13]], [[9, 13], [7, 11]], [[8, 12], [8, 12]]]) - print("test bmm grad success") + f.numpy(), [[[28, 34], [76, 98], [124, 162], [172, 226]]]) + print("test matmul_1 success") @jt.flag_scope(use_acl=1) - def test_avgpool(self): - a = jt.float32([[[[1,2,3,4],[2,3,4,1],[3,4,1,2],[4,1,2,3]]]]) - avg_pool = jt.nn.Pool(2, op='mean') - b = avg_pool(a) - np.testing.assert_allclose(b.numpy(), [[[[2, 3], [3, 2]]]]) - print("test avgpool success") - - @jt.flag_scope(use_acl=1) - def test_adaptive_maxpool2d(self): - a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], - [13, 14, 15, 16]]]]) - pool_1 = jt.nn.AdaptiveMaxPool2d((2, 2)) - pool_2 = jt.nn.AdaptiveMaxPool2d((3, 4)) - b = pool_1(a) - c = pool_2(a) - np.testing.assert_allclose(b.numpy(), [[[[6, 8], [14, 16]]]]) - np.testing.assert_allclose(c.numpy(), [[[[5,6,7,8],[9,10,11,12],[13,14,15,16]]]]) - print("test adaptive_maxpool2d success") - - @jt.flag_scope(use_acl=1) - def test_adaptive_maxpool2d_grad_1(self): - a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], - [13, 14, 15, 16]]]]) - max_pool = jt.nn.AdaptiveMaxPool2d((2, 2)) - optimizer = jt.optim.SGD([a], 0.1) - b = max_pool(a) - loss = b.sum() - optimizer.zero_grad() - optimizer.backward(loss) - optimizer.step() - res = a.opt_grad(optimizer) + def test_matmul_2(self): + a = jt.arange(16).reshape(1, 4, 4).float() + c = jt.arange(8).reshape(2, 4).float() + g = self.measure_time(lambda: jt.nn.matmul_transpose(a, c)) np.testing.assert_allclose( - res.numpy(), - [[[[0, 0, 0, 0], [0, 1, 0, 1], [0, 0, 0, 0], [0, 1, 0, 1]]]]) - print("test adaptive_maxpool2d_1 grad success") + g.numpy(), [[[14, 38], [38, 126], [62, 214], [86, 302]]]) + print("test matmul_2 success") @jt.flag_scope(use_acl=1) - def test_adaptive_maxpool2d_grad_2(self): - a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], - [13, 14, 15, 16]]]]) - max_pool = jt.nn.AdaptiveMaxPool2d((1, 3)) - optimizer = jt.optim.SGD([a], 0.1) - b = max_pool(a) - loss = b.sum() - optimizer.zero_grad() - optimizer.backward(loss) - optimizer.step() - res = a.opt_grad(optimizer) + def test_matmul_3(self): + a = jt.arange(16).reshape(1, 4, 4).float() + d = jt.arange(8).reshape(1, 2, 4).float() + h = self.measure_time(lambda: jt.nn.matmul_transpose(a, d)) np.testing.assert_allclose( - res.numpy(), - [[[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 1, 1, 1]]]]) - print("test adaptive_maxpool2d_2 grad success") + h.numpy(), [[[14, 38], [38, 126], [62, 214], [86, 302]]]) + print("test matmul_3 success") @jt.flag_scope(use_acl=1) - def test_adaptive_avgpool2d(self): - a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], - [13, 14, 15, 16]]]]) - pool_1 = jt.nn.AdaptiveAvgPool2d((2, 2)) - pool_2 = jt.nn.AdaptiveAvgPool2d((1, 3)) - b = pool_1(a) - c = pool_2(a) - np.testing.assert_allclose(b.numpy(), [[[[3.5, 5.5], [11.5, 13.5]]]]) - np.testing.assert_allclose(c.numpy(), [[[[7.5, 8.5, 9.5]]]]) - print("test adaptive_avgpool2d success") + def test_matmul_4(self): + a = jt.arange(16).reshape(1, 4, 4).float() + e = jt.arange(8).reshape(1, 4, 2).float() + i = self.measure_time(lambda: jt.matmul(a, e)) + np.testing.assert_allclose( + i.numpy(), [[[28, 34], [76, 98], [124, 162], [172, 226]]]) + print("test matmul_4 success") @jt.flag_scope(use_acl=1) - def 
test_adaptive_avgpool2d_grad(self): - a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], - [13, 14, 15, 16]]]]) - avg_pool = jt.nn.AdaptiveAvgPool2d((2, 2)) - optimizer = jt.optim.SGD([a], 0.1) - b = avg_pool(a) - loss = b.sum() - optimizer.zero_grad() - optimizer.backward(loss) - optimizer.step() - res = a.opt_grad(optimizer) - np.testing.assert_allclose( - res.numpy(), - [[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25], - [0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]]]]) - print("test adaptive_avgpool2d grad success") - + def test_matmul_5(self): + b = jt.arange(8).reshape(4, 2).float() + c = jt.arange(8).reshape(2, 4).float() + j = self.measure_time(lambda: jt.matmul(b, c)) + np.testing.assert_allclose(j.numpy(), + [[4, 5, 6, 7], [12, 17, 22, 27], + [20, 29, 38, 47], [28, 41, 54, 67]]) + print("test matmul_5 success") + @jt.flag_scope(use_acl=1) - def test_adaptive_avgpool2d_grad_2(self): - a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], - [13, 14, 15, 16]]]]) - avg_pool = jt.nn.AdaptiveAvgPool2d((1, 3)) - optimizer = jt.optim.SGD([a], 0.1) - b = avg_pool(a) - loss = b.sum() - optimizer.zero_grad() - optimizer.backward(loss) - optimizer.step() - res = a.opt_grad(optimizer) + def test_matmul_6(self): + b = jt.arange(8).reshape(4, 2).float() + bb = jt.arange(8).reshape(4, 2).float() + k = self.measure_time(lambda: jt.nn.matmul_transpose(b, bb)) np.testing.assert_allclose( - res.numpy(), - [[[[0.125, 0.25, 0.25, 0.125], [0.125, 0.25, 0.25, 0.125], - [0.125, 0.25, 0.25, 0.125], [0.125, 0.25, 0.25, 0.125]]]]) - print("test adaptive_avgpool2d_2 grad success") + k.numpy(), + [[1, 3, 5, 7], [3, 13, 23, 33], [5, 23, 41, 59], [7, 33, 59, 85]]) + print("test matmul_6 success") + + @jt.flag_scope(use_acl=1) + def test_grad_f_a(self): + a = jt.arange(16).reshape(1, 4, 4).float() + b = jt.arange(8).reshape(4, 2).float() + f = jt.matmul(a, b) + f_a = self.measure_time(lambda: jt.grad(f.sum(), a)) + np.testing.assert_allclose( + f_a.numpy(), + [[[1, 5, 9, 13], [1, 5, 9, 13], [1, 5, 9, 13], [1, 5, 9, 13]]]) + print("test grad_f_a success") + + @jt.flag_scope(use_acl=1) + def test_grad_f_b(self): + a = jt.arange(16).reshape(1, 4, 4).float() + b = jt.arange(8).reshape(4, 2).float() + f = jt.matmul(a, b) + f_b = self.measure_time(lambda: jt.grad(f.sum(), b)) + np.testing.assert_allclose(f_b.numpy(), + [[24, 24], [28, 28], [32, 32], [36, 36]]) + print("test grad_f_b success") + + @jt.flag_scope(use_acl=1) + def test_grad_g_a(self): + a = jt.arange(16).reshape(1, 4, 4).float() + c = jt.arange(8).reshape(2, 4).float() + g = jt.nn.matmul_transpose(a, c) + g_a = self.measure_time(lambda: jt.grad(g.sum(), a)) + np.testing.assert_allclose( + g_a.numpy(), + [[[4, 6, 8, 10], [4, 6, 8, 10], [4, 6, 8, 10], [4, 6, 8, 10]]]) + print("test grad_g_a success") + + @jt.flag_scope(use_acl=1) + def test_grad_g_c(self): + a = jt.arange(16).reshape(1, 4, 4).float() + c = jt.arange(8).reshape(2, 4).float() + g = jt.nn.matmul_transpose(a, c) + g_c = self.measure_time(lambda: jt.grad(g.sum(), c)) + np.testing.assert_allclose(g_c.numpy(), + [[24, 28, 32, 36], [24, 28, 32, 36]]) + print("test grad_g_c success") + + @jt.flag_scope(use_acl=1) + def test_grad_h_a(self): + a = jt.arange(16).reshape(1, 4, 4).float() + d = jt.arange(8).reshape(1, 2, 4).float() + h = jt.nn.matmul_transpose(a, d) + h_a = self.measure_time(lambda: jt.grad(h.sum(), a)) + np.testing.assert_allclose( + h_a.numpy(), + [[[4, 6, 8, 10], [4, 6, 8, 10], [4, 6, 8, 10], [4, 6, 8, 10]]]) + print("test grad_h_a success") + + 
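The expected values in the matmul gradient tests above follow from d(sum(a @ b))/da = ones @ b.T and d(sum(a @ b))/db = a.T @ ones. A quick NumPy re-derivation of the test_grad_f_a / test_grad_f_b fixtures, as a verification sketch only, not part of the test suite:

import numpy as np

a = np.arange(16, dtype=np.float32).reshape(1, 4, 4)
b = np.arange(8, dtype=np.float32).reshape(4, 2)
dout = np.ones((1, 4, 2), dtype=np.float32)        # gradient of sum() w.r.t. f = a @ b

grad_a = dout @ b.T                                # every row becomes b's row sums: [1, 5, 9, 13]
grad_b = (a.transpose(0, 2, 1) @ dout).sum(0)      # column sums of a, repeated for each output column

np.testing.assert_allclose(grad_a[0, 0], [1, 5, 9, 13])
np.testing.assert_allclose(grad_b, [[24, 24], [28, 28], [32, 32], [36, 36]])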
@jt.flag_scope(use_acl=1) + def test_grad_h_d(self): + a = jt.arange(16).reshape(1, 4, 4).float() + d = jt.arange(8).reshape(1, 2, 4).float() + h = jt.nn.matmul_transpose(a, d) + h_d = self.measure_time(lambda: jt.grad(h.sum(), d)) + np.testing.assert_allclose(h_d.numpy(), + [[[24, 28, 32, 36], [24, 28, 32, 36]]]) + print("test grad_h_d success") + + @jt.flag_scope(use_acl=1) + def test_grad_i_a(self): + a = jt.arange(16).reshape(1, 4, 4).float() + e = jt.arange(8).reshape(1, 4, 2).float() + i = jt.matmul(a, e) + i_a = self.measure_time(lambda: jt.grad(i.sum(), a)) + np.testing.assert_allclose( + i_a.numpy(), + [[[1, 5, 9, 13], [1, 5, 9, 13], [1, 5, 9, 13], [1, 5, 9, 13]]]) + print("test grad_i_a success") + + @jt.flag_scope(use_acl=1) + def test_grad_i_e(self): + a = jt.arange(16).reshape(1, 4, 4).float() + e = jt.arange(8).reshape(1, 4, 2).float() + i = jt.matmul(a, e) + i_e = self.measure_time(lambda: jt.grad(i.sum(), e)) + np.testing.assert_allclose(i_e.numpy(), + [[[24, 24], [28, 28], [32, 32], [36, 36]]]) + print("test grad_i_e success") + + @jt.flag_scope(use_acl=1) + def test_grad_j_b(self): + b = jt.arange(8).reshape(4, 2).float() + c = jt.arange(8).reshape(2, 4).float() + j = jt.matmul(b, c) + j_b = self.measure_time(lambda: jt.grad(j.sum(), b)) + np.testing.assert_allclose(j_b.numpy(), + [[6, 22], [6, 22], [6, 22], [6, 22]]) + print("test grad_j_b success") + + @jt.flag_scope(use_acl=1) + def test_grad_j_c(self): + b = jt.arange(8).reshape(4, 2).float() + c = jt.arange(8).reshape(2, 4).float() + j = jt.matmul(b, c) + j_c = self.measure_time(lambda: jt.grad(j.sum(), c)) + np.testing.assert_allclose(j_c.numpy(), + [[12, 12, 12, 12], [16, 16, 16, 16]]) + print("test grad_j_c success") + + @jt.flag_scope(use_acl=1) + def test_grad_k_b(self): + b = jt.arange(8).reshape(4, 2).float() + bb = jt.arange(8).reshape(4, 2).float() + k = jt.nn.matmul_transpose(b, bb) + k_b = self.measure_time(lambda: jt.grad(k.sum(), b)) + np.testing.assert_allclose(k_b.numpy(), + [[12, 16], [12, 16], [12, 16], [12, 16]]) + + print("test grad_k_b success") + + @jt.flag_scope(use_acl=1) + def test_grad_k_bb(self): + b = jt.arange(8).reshape(4, 2).float() + bb = jt.arange(8).reshape(4, 2).float() + k = jt.nn.matmul_transpose(b, bb) + k_bb = self.measure_time(lambda: jt.grad(k.sum(), bb)) + np.testing.assert_allclose(k_bb.numpy(), + [[12, 16], [12, 16], [12, 16], [12, 16]]) + print("test grad_k_bb success") + + @jt.flag_scope(use_acl=1) + def test_bmm_matmul(self): + a = jt.arange(16).reshape(1, 4, 4).float() + b = jt.arange(8).reshape(1, 4, 2).float() + d = self.measure_time(lambda: jt.bmm(a, b)) + np.testing.assert_allclose( + d.numpy(), + [[[28, 34], [76, 98], [124, 162], [172, 226]]] + ) + print("test bmm_matmul success") + + @jt.flag_scope(use_acl=1) + def test_bmm_transpose(self): + a = jt.arange(16).reshape(1, 4, 4).float() + c = jt.arange(8).reshape(1, 2, 4).float() + e = self.measure_time(lambda: jt.nn.bmm_transpose(a, c)) + np.testing.assert_allclose( + e.numpy(), + [[[14, 38], [38, 126], [62, 214], [86, 302]]] + ) + print("test bmm_transpose success") + + @jt.flag_scope(use_acl=1) + def test_bmm_grad_a(self): + a = jt.arange(16).reshape(1, 4, 4).float() + b = jt.arange(8).reshape(1, 4, 2).float() + d = jt.bmm(a, b) + d_a = self.measure_time(lambda: jt.grad(d.sum(), a)) + np.testing.assert_allclose( + d_a.numpy(), + [[[1, 5, 9, 13], [1, 5, 9, 13], [1, 5, 9, 13], [1, 5, 9, 13]]] + ) + print("test bmm_grad_a success") + + @jt.flag_scope(use_acl=1) + def test_bmm_grad_b(self): + a = 
jt.arange(16).reshape(1, 4, 4).float() + b = jt.arange(8).reshape(1, 4, 2).float() + d = jt.bmm(a, b) + d_b = self.measure_time(lambda: jt.grad(d.sum(), b)) + np.testing.assert_allclose( + d_b.numpy(), + [[[24, 24], [28, 28], [32, 32], [36, 36]]] + ) + print("test bmm_grad_b success") + + @jt.flag_scope(use_acl=1) + def test_bmm_transpose_grad_a(self): + a = jt.arange(16).reshape(1, 4, 4).float() + c = jt.arange(8).reshape(1, 2, 4).float() + e = jt.nn.bmm_transpose(a, c) + e_a = self.measure_time(lambda: jt.grad(e.sum(), a)) + np.testing.assert_allclose( + e_a.numpy(), + [[[4, 6, 8, 10], [4, 6, 8, 10], [4, 6, 8, 10], [4, 6, 8, 10]]] + ) + print("test bmm_transpose_grad_a success") + + @jt.flag_scope(use_acl=1) + def test_bmm_transpose_grad_c(self): + a = jt.arange(16).reshape(1, 4, 4).float() + c = jt.arange(8).reshape(1, 2, 4).float() + e = jt.nn.bmm_transpose(a, c) + e_c = self.measure_time(lambda: jt.grad(e.sum(), c)) + np.testing.assert_allclose( + e_c.numpy(), + [[[24, 28, 32, 36], [24, 28, 32, 36]]]) + print("test bmm_transpose_grad_c success") @jt.flag_scope(use_acl=1) def test_index(self): - a = jt.rand(2, 3) - [s1, s2] = jt.index(a.shape) + a = jt.ones(2, 3) + [s1, s2] = self.measure_time(lambda: jt.index(a.shape)) np.testing.assert_allclose(s1.numpy(), [[0, 0, 0], [1, 1, 1]]) np.testing.assert_allclose(s2.numpy(), [[0, 1, 2], [0, 1, 2]]) print("test index success") @@ -306,158 +488,497 @@ class TestACL(unittest.TestCase): @jt.flag_scope(use_acl=1) def test_gather(self): a = jt.array([[1, 2], [3, 4]]) - b = jt.gather(a, 1, jt.array([[0, 0], [1, 0]])) - np.testing.assert_allclose(b.numpy(), [[1, 1], [4, 3]]) - b = jt.gather(a, 0, jt.array([[0, 0], [1, 0]])) - np.testing.assert_allclose(b.numpy(), [[1, 2], [3, 2]]) - b = jt.gather(a, -1, jt.array([[0, 0], [1, 0]])) + b = self.measure_time( + lambda: jt.gather(a, 1, jt.array([[0, 0], [1, 0]]))) np.testing.assert_allclose(b.numpy(), [[1, 1], [4, 3]]) print("test gather success") @jt.flag_scope(use_acl=1) def test_gather_grad(self): a = jt.float32([[1, 2], [3, 4]]) - optimizer = jt.optim.SGD([a], 0.1) - b = jt.gather(a, 0, jt.array([[0, 0], [1, 0]])) - loss = b.sum() - optimizer.zero_grad() - optimizer.backward(loss) - optimizer.step() - res = a.opt_grad(optimizer) - np.testing.assert_allclose(res.numpy(), [[1, 2], [1, 0]]) + b = jt.gather(a, 1, jt.array([[0, 0], [1, 0]])) + res = self.measure_time(lambda: jt.grad(b.sum(), a)) + np.testing.assert_allclose(res.numpy(), [[2, 0], [1, 1]]) print("test gather grad success") @jt.flag_scope(use_acl=1) - def test_gather_grad_neg(self): - a = jt.float32([[4, 3], [2, 1]]) - optimizer = jt.optim.SGD([a], 0.1) - b = jt.gather(a, -1, jt.array([[0, 0], [1, 0]])) - loss = b.sum() - optimizer.zero_grad() - optimizer.backward(loss) - optimizer.step() - res = a.opt_grad(optimizer) - np.testing.assert_allclose(res.numpy(), [[2, 0], [1, 1]]) - print("test gather grad neg success") + def test_cumsum_1(self): + a = jt.array([1, 2, 3, 4, 5]) + b = self.measure_time(lambda: jt.cumsum(a)) + np.testing.assert_allclose(b.numpy(), [1, 3, 6, 10, 15]) + print("test cumsum (test case 1) success") @jt.flag_scope(use_acl=1) - def test_scatter_add(self): + def test_cumsum_2(self): + a = jt.array([[1, 2, 3], [4, 5, 6]]) + b = self.measure_time(lambda: jt.cumsum(a, dim = 0)) + np.testing.assert_allclose(b.numpy(), [[1, 2, 3], [5, 7, 9]]) + print("test cumsum (test case 2) success") + + @jt.flag_scope(use_acl=1) + def test_cumsum_grad(self): + a = jt.array([[1., 2., 3.], [4., 5., 6.]]) + b = jt.cumsum(a, dim = 0) + res = 
self.measure_time(lambda: jt.grad(b.sum(), a))
+        np.testing.assert_allclose(res.numpy(), [[2., 2., 2.], [1., 1., 1.]])
+        print("test cumsum grad success")
+
+    @jt.flag_scope(use_acl=1)
+    def test_any_1(self):
+        a = jt.array([[1, 0], [0, 4]])
+        b = self.measure_time(lambda: jt.any(a))
+        assert b.item() == True
+        print("test any (test case 1) success")
+
+    @jt.flag_scope(use_acl=1)
+    def test_any_2(self):
+        a = jt.array([[1.0, 0.0]])
+        b = self.measure_time(lambda: jt.any(a))
+        assert b.item() == True
+        print("test any (test case 2) success")
+
+    @jt.flag_scope(use_acl=1)
+    def test_any_3(self):
+        a = jt.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
+        b = self.measure_time(lambda: jt.any(a))
+        assert b.item() == False
+        print("test any (test case 3) success")
+
+    @jt.flag_scope(use_acl=1)
+    def test_any_4(self):
+        a = jt.array([[False, False, False], [False, False, False]])
+        b = self.measure_time(lambda: jt.any(a))
+        assert b.item() == False
+        print("test any (test case 4) success")
+
+    @jt.flag_scope(use_acl=1)
+    def test_any_5(self):
+        a = jt.array([[False, True, False], [False, False, True],
+                      [True, True, False]])
+        b = self.measure_time(lambda: jt.any(a))
+        assert b.item() == True
+        print("test any (test case 5) success")
+
+    @jt.flag_scope(use_acl=1)
+    def test_any_6(self):
+        a = jt.array([[False, True, False], [False, False, True],
+                      [True, True, False]])
+        b = self.measure_time(lambda: a.any())
+        assert b.item() == True
+        print("test any (test case 6) success")
+
+    @jt.flag_scope(use_acl=1)
+    def test_any_7(self):
+        a = jt.array([[False, False, False], [False, False, True],
+                      [True, True, False]])
+        b = self.measure_time(lambda: jt.any(a, dim=1))
+        assert (b.numpy() == [False, True, True]).all(), "Unexpected result for dim=1"
+        print("test any (test case 7) success")
+
+    @jt.flag_scope(use_acl=1)
+    def test_any_8(self):
+        a = jt.array([[False, True, False], [False, False, True],
+                      [False, True, False]])
+        b = self.measure_time(lambda: jt.any(a, dim=0))
+        assert (b.numpy() == [False, True, True]).all(), "Unexpected result for dim=0"
+        print("test any (test case 8) success")
+
+    @jt.flag_scope(use_acl=1)
+    def test_any_9(self):
+        a = jt.array([[False, True, False], [False, False, True],
+                      [False, True, False]])
+        b = self.measure_time(lambda: a.any(dim=0))
+        assert (b.numpy() == [False, True, True]).all(), "Unexpected result for dim=0"
+        print("test any (test case 9) success")
+
+    @jt.flag_scope(use_acl=1)
+    def test_any_10(self):
+        # check along dim=0 whether each column contains a nonzero element
+        a = jt.array([[0, 1, 0], [0, 0, 0]])
+        b = self.measure_time(lambda: jt.any(a, dim=0))
+        assert (b.numpy() == [False, True, False]).all(), "Unexpected result for dim=0"
+        print("test any (test case 10) success")
+
+    @jt.flag_scope(use_acl=1)
+    def test_any_11(self):
+        # check along dim=0 whether each column contains a nonzero element
+        a = jt.array([[0.0, 1.0, -1.0], [0, 0, 0]])
+        b = self.measure_time(lambda: jt.any(a, dim=0))
+        assert (b.numpy() == [False, True, True]).all(), "Unexpected result for dim=0"
+        print("test any (test case 11) success")
+
+    @jt.flag_scope(use_acl=1)
+    def test_any_12(self):
+        # check along dim=1 whether each row contains a nonzero element
+        a = jt.array([[0.0, 1.0, -1.0], [0, 0, 0]])
+        b = self.measure_time(lambda: jt.any(a, dim=1))
+        assert (b.numpy() == [True, False]).all(), "Unexpected result for dim=1"
+        print("test any (test case 12) success")
+
+    @jt.flag_scope(use_acl=1)
+    def test_scatter(self):
         a = jt.array([[1, 2], [3, 4]])
         b = jt.array([[0, 0], [0, 0]])
-        b = jt.scatter(b, 1, jt.array([[0, 0], [1, 0]]), a, reduce="add")
-        np.testing.assert_allclose(b.numpy(), 
[[3, 0], [4, 3]]) - print("test scatter add success") + c = self.measure_time(lambda: jt.scatter( + b, 1, jt.array([[0, 0], [1, 0]]), a, reduce="add")) + np.testing.assert_allclose(c.numpy(), [[45, 0], [60, 45]]) + print("test scatter success") @jt.flag_scope(use_acl=1) - def test_scatter_multi(self): - a = jt.array([[1, 2], [3, 4]]) - b = jt.array([[5, 6], [7, 8]]) - b = jt.scatter(b, 0, jt.array([[0, 0], [1, 0]]), a, reduce="multiply") - np.testing.assert_allclose(b.numpy(), [[5, 48], [21, 8]]) - print("test scatter multiply success") - - @jt.flag_scope(use_acl=1) - def test_scatter_add_grad(self): + def test_scatter_grad(self): a = jt.float32([[1, 2], [3, 4]]) b = jt.float32([[0, 0], [0, 0]]) - optimizer = jt.optim.SGD([a, b], 0.1) - c = jt.scatter(b, 1, jt.array([[0, 0], [1, 0]]), a, reduce="add") - loss = c.max() - optimizer.zero_grad() - optimizer.backward(loss) - optimizer.step() - res_a = a.opt_grad(optimizer) - res_b = b.opt_grad(optimizer) + + def tmp_func(a, b): + c = jt.scatter(b, 1, jt.array([[0, 0], [1, 0]]), a, reduce="add") + res_a = jt.grad(c.max(), a) + res_b = jt.grad(c.max(), b) + return res_a, res_b + + res_a, res_b = self.measure_time(lambda: tmp_func(a, b)) np.testing.assert_allclose(res_a.numpy(), [[0, 0], [0, 1]]) np.testing.assert_allclose(res_b.numpy(), [[0, 0], [1, 0]]) - print("test scatter add grad success") + print("test scatter grad success") @jt.flag_scope(use_acl=1) - def test_scatter_mult_grad(self): - a = jt.float32([[1, 2], [3, 4]]) - b = jt.float32([[5, 6], [7, 8]]) - optimizer = jt.optim.SGD([a, b], 0.1) - c = jt.scatter(b, 1, jt.array([[0, 0], [1, 0]]), a, reduce="multiply") - loss = c.max() - optimizer.zero_grad() - optimizer.backward(loss) - optimizer.step() - res_a = a.opt_grad(optimizer) - res_b = b.opt_grad(optimizer) - np.testing.assert_allclose(res_a.numpy(), [[0, 6], [0, 6]]) - np.testing.assert_allclose(res_b.numpy(), [[0, 8], [0, 0]]) - print("test scatter mult grad success") + def test_nonzero_1(self): + a = jt.array([[1, 0], [0, 4]]) + b = self.measure_time(lambda: a.nonzero()) + np.testing.assert_allclose(b.numpy(), [[0, 0], [1, 1]]) + print("test nonzero (test case 1) success") @jt.flag_scope(use_acl=1) - def test_where(self): + def test_nonzero_2(self): + a = jt.array([[1.0, 0.0], [0.0, 2.0]]) + b = self.measure_time(lambda: a.nonzero()) + np.testing.assert_allclose(b.numpy(), [[0, 0], [1, 1]]) + print("test nonzero (test case 2) success") + + @jt.flag_scope(use_acl=1) + def test_nonzero_3(self): + a = jt.array([[[True, False, True], [False, True, False]], + [[True, False, True], [False, True, False]]]) + b = self.measure_time(lambda: a.nonzero()) + np.testing.assert_allclose( + b.numpy(), + [[0, 0, 0], [0, 0, 2], [0, 1, 1], [1, 0, 0], [1, 0, 2], [1, 1, 1]]) + print("test nonzero (test case 3) success") + + @jt.flag_scope(use_acl=1) + def test_floor_int(self): + a = jt.array([[1.2, 0.0], [-0.1, 123.123]]) + b = self.measure_time(lambda: jt.floor_int(a)) + np.testing.assert_allclose(b.numpy(), [[1, 0], [-1, 123]]) + print("test floor_int success") + + @jt.flag_scope(use_acl=1) + def test_where_cond_expr(self): a = jt.array([[1, 2], [3, 4]]) b = jt.ones(2, 2) - c = jt.where(a > 2, a, b) + c = self.measure_time(lambda: jt.where(a > 2, a, b)) np.testing.assert_allclose(c.numpy(), [[1, 1], [3, 4]]) - print("test where success") - - @jt.flag_scope(use_acl=1) - def test_where_2(self): - a = jt.array([[1, 2], [3, 4]]) - b = jt.array([[5, 6], [7, 8]]) - cond = jt.array([[1, 0], [0, 1]]) - c = jt.where(cond, a, b) - 
np.testing.assert_allclose(c.numpy(), [[1, 6], [7, 4]]) - print("test where_2 success") + print("test where (cond expr) success") @jt.flag_scope(use_acl=1) def test_where_grad(self): - a = jt.array([[1, 2], [3, 4]]) - b = jt.array([[5, 6], [7, 8]]) - cond = jt.array([[1, 0], [0, 1]]) - c = jt.where(cond, a, b) - optimizer = jt.optim.SGD([a, b], 0.1) - loss = c.sum() - optimizer.zero_grad() - optimizer.backward(loss) - optimizer.step() - res_a = a.opt_grad(optimizer) - res_b = b.opt_grad(optimizer) + a = jt.float32([[1, 2], [3, 4]]) + b = jt.array([[2., 2.], [2., 2.]]) + c = jt.where(a > 2, a, b) + res_a = self.measure_time(lambda: jt.grad(c.sum(), a)) + res_b = self.measure_time(lambda: jt.grad(c.sum(), b)) np.testing.assert_allclose(res_a.numpy(), [[0, 0], [1, 1]]) np.testing.assert_allclose(res_b.numpy(), [[1, 1], [0, 0]]) print("test where grad success") @jt.flag_scope(use_acl=1) - def test_where_grad_2(self): - a = jt.float32([[1, 2], [3, 4]]) - b = jt.array([[2., 2.], [2., 2.]]) - c = jt.where(a > 2, a, b) - optimizer = jt.optim.SGD([a, b], 0.1) - loss = c.sum() - optimizer.zero_grad() - optimizer.backward(loss) - optimizer.step() - res_a = a.opt_grad(optimizer) - res_b = b.opt_grad(optimizer) - np.testing.assert_allclose(res_a.numpy(), [[0, 0], [1, 1]]) - np.testing.assert_allclose(res_b.numpy(), [[1, 1], [0, 0]]) - print("test where grad 2 success") + def test_where_unary_1(self): + a = jt.array([[1.0, 0.0], [0.0, 2.0]]) + b = self.measure_time(lambda: jt.where(a)) + # assert type(b) is tuple + assert len(b) == a.ndim + np.testing.assert_allclose(b[0].numpy(), [0, 1]) + np.testing.assert_allclose(b[1].numpy(), [0, 1]) + print("test where (unary) (test case 1) success") + + @jt.flag_scope(use_acl=1) + def test_where_unary_2(self): + a = jt.array([[1.0, -1.2], [0.13, 0.0]]) + b = self.measure_time(lambda: jt.where(a)) + # assert type(b) is tuple + assert len(b) == a.ndim + np.testing.assert_allclose(b[0].numpy(), [0, 0, 1]) + np.testing.assert_allclose(b[1].numpy(), [0, 1, 0]) + print("test where (unary) (test case 2) success") @jt.flag_scope(use_acl=1) def test_flip(self): a = jt.array([[1., 2.], [3., 4.]]) - b = a.flip((0, 1)) - np.testing.assert_allclose(b.numpy(), [[4, 3], [2, 1]]) + b = self.measure_time(lambda: a.flip()) + c = self.measure_time(lambda: a.flip(1)) + d = self.measure_time(lambda: a.flip((0, 1))) + np.testing.assert_allclose(b.numpy(), [[3, 4], [1, 2]]) + np.testing.assert_allclose(c.numpy(), [[2, 1], [4, 3]]) + np.testing.assert_allclose(d.numpy(), [[4, 3], [2, 1]]) print("test flip success") @jt.flag_scope(use_acl=1) def test_flip_grad(self): a = jt.float32([[1, 2], [3, 4]]) - optimizer = jt.optim.SGD([a], 0.1) b = a.flip((0, 1)) - loss = b.max() - optimizer.zero_grad() - optimizer.backward(loss) - optimizer.step() - res = a.opt_grad(optimizer) + res = self.measure_time(lambda: jt.grad(b.max(), a)) np.testing.assert_allclose(res.numpy(), [[0, 0], [0, 1]]) print("test flip grad success") + + @jt.flag_scope(use_acl=1) + def test_array(self): + a = self.measure_time(lambda: jt.array([1, 2, 3])) + np.testing.assert_allclose(a.numpy(), [1, 2, 3]) + print("test array success") + @jt.flag_scope(use_acl=1) + def test_add(self): + a = jt.array([1, 2, 3]) + b = self.measure_time(lambda: a + a) + np.testing.assert_allclose(b.numpy(), [2, 4, 6]) + print("test add success") + + @jt.flag_scope(use_acl=1) + def test_add_float(self): + a = jt.array([1.0, 2.0, 3.0]) + b = self.measure_time(lambda: a + a) + np.testing.assert_allclose(b.numpy(), [2, 4, 6]) + print("test add float 
success") + + @jt.flag_scope(use_acl=1) + def test_array_cast(self): + x = np.random.rand(10) + y = self.measure_time(lambda: jt.float32(x)) + np.testing.assert_allclose(x, y.numpy()) + print("test array cast success") + + @jt.flag_scope(use_acl=1) + def test_array_cast_half(self): + x = np.random.rand(10).astype("float32") + y = self.measure_time(lambda: jt.float16(x)) + np.testing.assert_allclose(x.astype("float16"), y.numpy()) + print("test array cast half success") + + @jt.flag_scope(use_acl=1) + def test_rand(self): + a = self.measure_time(lambda: jt.rand(10)) + b = self.measure_time(lambda: a * 10) + b.sync() + print("test rand success") + + @jt.flag_scope(use_acl=1) + def test_max(self): + x = jt.rand(3, 3) + y = self.measure_time(lambda: x.max(1)) + ny = x.data.max(1) + np.testing.assert_allclose(y.data, ny) + print("test max success") + + @jt.flag_scope(use_acl=1) + def test_sum(self): + x = jt.rand(3, 3).float16() + y = self.measure_time(lambda: x.sum(1)) + ny = x.data.sum(1) + np.testing.assert_allclose(y.data, ny) + print("test sum success") + + @jt.flag_scope(use_acl=1) + def test_broadcast(self): + x = jt.rand(3) + y = self.measure_time(lambda: x.broadcast([3, 3])) + with jt.flag_scope(use_acl=0): + ny = jt.broadcast(x, shape=(3, 3)).data + np.testing.assert_allclose(y.data, ny) + print("test broadcast success") + + @jt.flag_scope(use_acl=1) + def test_flashattention(self): + bsz = 1 + seq = 4 + headnum = 1 + headdim = 4 + xq = jt.ones(bsz,headnum,seq,headdim) + xk = jt.ones(bsz,headnum,seq,headdim) + xv = jt.ones(bsz,headnum,seq,headdim) + attention = jt.nn.FlashAttention(headnum,"BNSD") + xo = self.measure_time(lambda: attention(xq,xk,xv)) + np.testing.assert_allclose(xo.numpy(), + [[[[1., 1., 1., 1.], + [1., 1., 1., 1.], + [1., 1., 1., 1.], + [1., 1., 1., 1.]]]]) + print("test flashattention success") + + @jt.flag_scope(use_acl=1) + def test_flashattention_grad(self): + bsz = 1 + seq = 4 + headnum = 1 + headdim = 4 + xq = jt.ones(bsz,headnum,seq,headdim) + xk = jt.ones(bsz,headnum,seq,headdim) + xv = jt.ones(bsz,headnum,seq,headdim) + attention = jt.nn.FlashAttention(headnum,"BNSD") + xo = attention(xq,xk,xv) + dxq = self.measure_time(lambda: jt.grad(xo.max(), xq)) + dxk = self.measure_time(lambda: jt.grad(xo.max(), xk)) + dxv = self.measure_time(lambda: jt.grad(xo.max(), xv)) + np.testing.assert_allclose(dxq.numpy(), + [[[[0., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.]]]]) + np.testing.assert_allclose(dxk.numpy(), + [[[[0., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.]]]]) + np.testing.assert_allclose(dxv.numpy(), + [[[[1., 1., 1., 1.], + [1., 1., 1., 1.], + [1., 1., 1., 1.], + [1., 1., 1., 1.]]]]) + print("test flashattention grad success") + + @jt.flag_scope(use_acl=1) + def test_softmax(self): + a = jt.array([[1, 2], [3, 4]]) + res = self.measure_time(lambda: jt.nn.softmax(a, dim = -1)) + np.testing.assert_allclose(res.numpy(), [[0.26894143, 0.7310586], [0.26894143, 0.7310586]]) + print("test softmax success") + + @jt.flag_scope(use_acl=1) + def test_softmax_grad(self): + a = jt.float32([[1, 2], [3, 4]]) + b = jt.nn.softmax(a, dim = -1) + res = self.measure_time(lambda: jt.grad(b.max(), a)) + np.testing.assert_allclose(res.numpy(), [[-0.19661194, 0.19661193], [-0.19661194, 0.19661193]]) + print("test softmax grad success") + + @jt.flag_scope(use_acl=1) + def test_relu(self): + a = jt.array([[1, -2, 3], [-4, 5, -6]]) + res = self.measure_time(lambda: jt.nn.relu(a)) + np.testing.assert_allclose(res.numpy(), [[1, 0, 
3], [0, 5, 0]]) + print("test relu success") + + @jt.flag_scope(use_acl=1) + def test_relu_grad(self): + a = jt.array([[1, -2, 3], [-4, 5, -6]]).float() + b = jt.nn.relu(a) + res = self.measure_time(lambda: jt.grad(b.max(), a)) + np.testing.assert_allclose(res.numpy(), [[0, 0, 0], [0, 1, 0]]) + print("test relu grad success") + + @jt.flag_scope(use_acl=1) + def test_silu(self): + a = jt.array([[1, 2, 3]]) + res = self.measure_time(lambda: jt.nn.silu(a)) + np.testing.assert_allclose(res.numpy(), [[0.7310586, 1.761594, 2.8577225]]) + print("test silu success") + + @jt.flag_scope(use_acl=1) + def test_silu_grad(self): + a = jt.float32([[1, 2, 3]]) + b = jt.nn.silu(a) + res = self.measure_time(lambda: jt.grad(b.max(), a)) + np.testing.assert_allclose(res.numpy(), [[0, 0, 1.0881041]]) + print("test silu grad success") + + @jt.flag_scope(use_acl=1) + def test_sigmoid(self): + a = jt.array([[1, 2, 3]]) + sig = jt.nn.Sigmoid() + res = self.measure_time(lambda: sig(a)) + np.testing.assert_allclose(res.numpy(), [[0.7310586, 0.880797, 0.95257413]]) + print("test sigmoid success") + + @jt.flag_scope(use_acl=1) + def test_sigmoid_grad(self): + a = jt.float32([[1, 2, 3]]) + sig = jt.nn.Sigmoid() + b = sig(a) + res = self.measure_time(lambda: jt.grad(b.sum(), a)) + np.testing.assert_allclose(res.numpy(), [[0.19661193, 0.1049936, 0.04517666]], rtol=1e-6, atol=1e-8) + print("test sigmoid grad success") + + @jt.flag_scope(use_acl=1) + def test_dropout(self): + jt.misc.set_global_seed(0) + x = jt.ones(3,3) + res = self.measure_time(lambda: jt.nn.dropout(x, is_train=True)) + np.testing.assert_allclose(res.numpy(),[[0, 2, 2],[0, 2, 0],[0, 2, 2]]) + print("test dropout success") + + @jt.flag_scope(use_acl=1) + def test_dropout_grad(self): + jt.misc.set_global_seed(0) + a = jt.ones(3,3) + b = jt.nn.dropout(a, is_train=True) + loss = b.sum() + res = self.measure_time(lambda: jt.grad(b.sum(), a)) + np.testing.assert_allclose(res.numpy(),[[1, 1, 1],[1, 1, 1],[1, 1, 1]]) + print("test dropout grad success") + + @jt.flag_scope(use_acl=1) + def test_leakyrelu(self): + a = jt.array([[1, -2, 3], [-4, 5, -6]]) + res = self.measure_time(lambda: jt.nn.leaky_relu(a)) + np.testing.assert_allclose(res.numpy(), [[1, -0.02, 3], [-0.04, 5, -0.06]]) + print("test leakyrelu success") + + @jt.flag_scope(use_acl=1) + def test_leakyrelu_grad(self): + a = jt.array([[1, -2, 3], [-4, 5, -6]]).float() + b = jt.nn.leaky_relu(a) + res = self.measure_time(lambda: jt.grad(b.max(), a)) + np.testing.assert_allclose(res.numpy(), [[0, 0, 0], [0, 1, 0]]) + print("test leakyrelu grad success") + + @jt.flag_scope(use_acl=1) + def test_embedding(self): + weight = jt.array([[0, 0, 3, 1], [2, 0, 3, 1], [0, 0, 0, 0]]) + input = jt.array([0, 2, 1]) + res = self.measure_time(lambda: jt.nn.embedding(input, weight)) + np.testing.assert_allclose(res.numpy(), [[0, 0, 3, 1], [0, 0, 0, 0], [2, 0, 3, 1]]) + print("test embedding success") + + # @jt.flag_scope(use_acl=1) + # def test_embedding_grad(self): + # a = jt.array([[0,0,3,1],[2,0,3,1],[0,0,0,0]]).float() + # input = jt.array([0,2,1]) + # b = jt.nn.embedding(input, a) + # res = self.measure_time(lambda: jt.grad(b.max(), a)) + # np.testing.assert_allclose(res.numpy(), [[0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]]) + # print("test embedding grad success") + + @jt.flag_scope(use_acl=1) + def test_stack(self): + a = jt.array([1, 2, 3]) + b = jt.array([4, 5, 6]) + c = self.measure_time(lambda: jt.stack([a, b])) + d = self.measure_time(lambda: jt.stack([a, b], dim = 1)) + np.testing.assert_allclose(c.numpy(), [[1, 
2, 3], [4, 5, 6]])
+        np.testing.assert_allclose(d.numpy(), [[1, 4], [2, 5], [3, 6]])
+        print("test stack success")
+
+    @jt.flag_scope(use_acl=1)
+    def test_is_nan(self):
+        x = jt.array([1.0, float('nan'), float('inf'), float('-inf'), -1.0, 2.0, 0.0])
+        res = self.measure_time(lambda: jt.isnan(x))
+        np.testing.assert_allclose(res.numpy(), [False, True, False, False, False, False, False])
+        print("test is nan success")
+
+    @jt.flag_scope(use_acl=1)
+    def test_is_inf(self):
+        x = jt.array([1.0, float('nan'), float('inf'), float('-inf'), -1.0, 2.0, 0.0])
+        res = self.measure_time(lambda: jt.isinf(x))
+        np.testing.assert_allclose(res.numpy(), [False, False, True, True, False, False, False])
+        print("test is inf success")
 
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main()
\ No newline at end of file
diff --git a/python/jittor/test/test_complex.py b/python/jittor/test/test_complex.py
index 4c1a6c53..430173e9 100644
--- a/python/jittor/test/test_complex.py
+++ b/python/jittor/test/test_complex.py
@@ -2,7 +2,6 @@ import jittor as jt
 from jittor.nn import ComplexNumber
 import unittest
 import numpy as np
-from functools import partial
 
 __skip_torch_test = False
 try:
@@ -11,15 +10,6 @@ except:
     __skip_torch_test = True
 
 class TestResultAndGrad:
-    def flatten_list(self, list_like):
-        results = []
-        if isinstance(list_like, (list, tuple)):
-            for x in list_like:
-                results.extend(self.flatten_list(x))
-            return results
-        else:
-            return [list_like]
-
     def check_results(self, rlist1, rlist2):
         assert len(rlist1) == len(rlist2)
         for r1, r2 in zip(rlist1, rlist2):
@@ -46,21 +36,13 @@ class TestResultAndGrad:
             grads.append(g.detach().cpu().numpy())
         return grads
 
-    def run_jittor_op(self, op, input_list, weights=None, key_names=None, **kwargs):
-        def _np_to_jittor(x):
-            if isinstance(x, np.ndarray):
-                if x.dtype == np.complex64 or x.dtype == np.complex128:
-                    nx = np.stack([np.real(x), np.imag(x)], axis=-1)
-                    return ComplexNumber(jt.array(nx, dtype=jt.float32), is_concat_value=True)
-                elif x.dtype == np.float32 or x.dtype == np.float64:
-                    return jt.array(x, dtype=jt.float32)
-                else:
-                    assert False
-            elif isinstance(x, (list, tuple)):
-                nx = [_np_to_jittor(vx) for vx in x]
-                if isinstance(x, tuple):
-                    return tuple(nx)
-                return nx
+    def run_jittor_op(self, op, input_list, weights=None):
+        def _np_to_jittor(x:np.ndarray):
+            if x.dtype == np.complex64 or x.dtype == np.complex128:
+                nx = np.stack([np.real(x), np.imag(x)], axis=-1)
+                return ComplexNumber(jt.array(nx, dtype=jt.float32), is_concat_value=True)
+            elif x.dtype == np.float32 or x.dtype == np.float64:
+                return jt.array(x, dtype=jt.float32)
             else:
                 assert False
         def _jittor_to_np(x):
@@ -69,19 +51,11 @@ class TestResultAndGrad:
             elif isinstance(x, ComplexNumber):
                 return x.real.numpy() + 1j * x.imag.numpy()
             assert False
-        ninput_list = [_np_to_jittor(x) for x in input_list]
-        if key_names != None:
-            assert len(ninput_list) == len(key_names)
-            nkwargs = kwargs.copy()
-            for k, v in zip(key_names, ninput_list):
-                nkwargs[k] = v
-            output_list = op(**nkwargs)
-        else:
-            output_list = op(*ninput_list, **kwargs)
+        ninput_list = [_np_to_jittor(x) for x in input_list]
+        output_list = op(*ninput_list)
         if isinstance(output_list, (jt.Var, ComplexNumber)):
             output_list = [output_list]
-        output_list = self.flatten_list(output_list)
         losses = []
         if weights is None:
             weights = []
@@ -99,31 +73,15 @@ class TestResultAndGrad:
         output_list = [_jittor_to_np(x) for x in output_list]
         return ninput_list, output_list, losses, weights
 
-    def run_torch_op(self, op, input_list, weights=None, 
key_names=None, **kwargs): - def _np_to_torch(x): - if isinstance(x, np.ndarray): - return torch.from_numpy(x).requires_grad_(True) - elif isinstance(x, (list, tuple)): - nx = [_np_to_torch(vx) for vx in x] - if isinstance(x, tuple): - return tuple(nx) - return nx - else: - assert False + def run_torch_op(self, op, input_list, weights=None): + def _np_to_torch(x:np.ndarray): + return torch.from_numpy(x).requires_grad_(True) def _torch_to_np(x:torch.Tensor) -> np.ndarray: return x.detach().cpu().numpy() ninput_list = [_np_to_torch(x) for x in input_list] - if key_names != None: - assert len(ninput_list) == len(key_names) - nkwargs = kwargs.copy() - for k, v in zip(key_names, ninput_list): - nkwargs[k] = v - output_list = op(**nkwargs) - else: - output_list = op(*ninput_list, **kwargs) + output_list = op(*ninput_list) if isinstance(output_list, torch.Tensor): output_list = [output_list] - output_list = self.flatten_list(output_list) losses = [] if weights is None: weights = [] @@ -141,10 +99,10 @@ class TestResultAndGrad: output_list = [_torch_to_np(x) for x in output_list] return ninput_list, output_list, losses, weights - def check_op_with_torch(self, jittor_op, torch_op, input_list, check_grad=True, jittor_knames=None, torch_knames=None, **kwargs): + def check_op_with_torch(self, jittor_op, torch_op, input_list, check_grad=True): weights = None - jittor_input, jittor_output, jittor_losses, weights = self.run_jittor_op(jittor_op, input_list, weights, key_names=jittor_knames, **kwargs) - torch_input, torch_output, torch_losses, weights = self.run_torch_op(torch_op, input_list, weights, key_names=torch_knames, **kwargs) + jittor_input, jittor_output, jittor_losses, weights = self.run_jittor_op(jittor_op, input_list, weights) + torch_input, torch_output, torch_losses, weights = self.run_torch_op(torch_op, input_list, weights) self.check_results(jittor_output, torch_output) if check_grad: @@ -237,249 +195,6 @@ class TestComplexLinalg(unittest.TestCase, TestResultAndGrad): inputs = [m1] self.check_op_with_numpy(jt.linalg.svd, np.linalg.svd, inputs) -class TestTensordot(unittest.TestCase, TestResultAndGrad): - def random_complex_matrix(self, shape): - r = np.random.randn(*shape) - i = np.random.randn(*shape) - return r + 1j * i - - def random_real_matrix(self, shape): - return np.random.randn(*shape) - - def test_complex_tensordot_numberdim(self): - s1 = (3, 4, 5) - s2 = (4, 5, 6) - dims = 2 - m1 = self.random_complex_matrix(s1) - m2 = self.random_complex_matrix(s2) - inputs = [m1, m2] - self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims) - - def test_complex_tensordot_tupledim(self): - s1 = (3, 5, 4, 6) - s2 = (6, 4, 5, 3) - dims = ([2, 1, 3], [1, 2, 0]) - m1 = self.random_complex_matrix(s1) - m2 = self.random_complex_matrix(s2) - inputs = [m1, m2] - self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims) - - def test_real_tensordot_numberdim(self): - s1 = (3, 4, 5) - s2 = (4, 5, 6) - dims = 2 - m1 = self.random_real_matrix(s1) - m2 = self.random_real_matrix(s2) - inputs = [m1, m2] - self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims) - - def test_real_tensordot_tupledim(self): - s1 = (3, 5, 4, 6) - s2 = (6, 4, 5, 3) - dims = ([2, 1, 3], [1, 2, 0]) - m1 = self.random_real_matrix(s1) - m2 = self.random_real_matrix(s2) - inputs = [m1, m2] - self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims) - -class TestKron(unittest.TestCase, TestResultAndGrad): - def random_complex_matrix(self, shape): - 
r = np.random.randn(*shape) - i = np.random.randn(*shape) - return r + 1j * i - - def random_real_matrix(self, shape): - return np.random.randn(*shape) - - def test_complex_firstlarge(self): - s1 = (2, 3, 4) - s2 = (5, 2) - m1 = self.random_complex_matrix(s1) - m2 = self.random_complex_matrix(s2) - inputs = [m1, m2] - self.check_op_with_torch(jt.nn.kron, torch.kron, inputs) - - def test_complex_second_large(self): - s1 = (2, 3) - s2 = (5, 2, 4) - m1 = self.random_complex_matrix(s1) - m2 = self.random_complex_matrix(s2) - inputs = [m1, m2] - self.check_op_with_torch(jt.nn.kron, torch.kron, inputs) - - def test_real_firstlarge(self): - s1 = (2, 3, 4) - s2 = (5, 2) - m1 = self.random_real_matrix(s1) - m2 = self.random_real_matrix(s2) - inputs = [m1, m2] - self.check_op_with_torch(jt.nn.kron, torch.kron, inputs) - - def test_real_second_large(self): - s1 = (2, 3) - s2 = (5, 2, 4) - m1 = self.random_real_matrix(s1) - m2 = self.random_real_matrix(s2) - inputs = [m1, m2] - self.check_op_with_torch(jt.nn.kron, torch.kron, inputs) - -@unittest.skipIf(__skip_torch_test, "No Torch found") -class TestGradFunctional(unittest.TestCase, TestResultAndGrad): - def random_complex_matrix(self, shape): - r = np.random.randn(*shape) - i = np.random.randn(*shape) - return r + 1j * i - - def random_real_matrix(self, shape): - return np.random.randn(*shape) * 0.0 + 1.0 - - def test_real_jvp_exp(self): - def exp_reducer(x): - return x.exp().sum(dim=1) - s1 = (5, 6) - m1 = self.random_real_matrix(s1) - m2 = self.random_real_matrix(s1) - inputs = [m1, m2] - self.check_op_with_torch( - partial(jt.gradfunctional.jvp, func=exp_reducer, create_graph=True), - partial(torch.autograd.functional.jvp, func=exp_reducer, create_graph=True), - inputs, - jittor_knames = ['inputs', 'v'], - torch_knames = ['inputs', 'v'], - check_grad=False) - - def test_complex_jvp_exp(self): - def exp_reducer(x): - return x.exp().sum(1) - s1 = (5, 6) - m1 = self.random_complex_matrix(s1) - m2 = self.random_complex_matrix(s1) - inputs = [m1, m2] - self.check_op_with_torch( - partial(jt.gradfunctional.jvp, func=exp_reducer, create_graph=True), - partial(torch.autograd.functional.jvp, func=exp_reducer, create_graph=True), - inputs, - jittor_knames = ['inputs', 'v'], - torch_knames = ['inputs', 'v'], - check_grad=False, - ) - - def test_real_jvp_add(self): - w1, w2 = np.random.rand(), np.random.rand() - def adder(x, y): - return w1 * x + w2 * y - s1 = (5, 6) - m1 = self.random_real_matrix(s1) - m2 = self.random_real_matrix(s1) - m3 = self.random_real_matrix(s1) - m4 = self.random_real_matrix(s1) - inputs = [(m1, m2), (m3, m4)] - self.check_op_with_torch( - partial(jt.gradfunctional.jvp, func=adder, create_graph=True), - partial(torch.autograd.functional.jvp, func=adder, create_graph=True), - inputs, - jittor_knames = ['inputs', 'v'], - torch_knames = ['inputs', 'v'], - check_grad=False, - ) - - def test_complex_jvp_add(self): - w1r, w1i = np.random.rand(), np.random.rand() - w2r, w2i = np.random.rand(), np.random.rand() - def adder_pt(x, y): - return (w1r + 1j * w1i) * x + (w2r + 1j * w2i) * y - def adder_jt(x, y): - w1 = ComplexNumber(real=jt.array(w1r).reshape(1,1), imag = jt.array(w1i).reshape(1,1)) - w2 = ComplexNumber(real=jt.array(w2r).reshape(1,1), imag = jt.array(w2i).reshape(1,1)) - return w1 * x + w2 * y - s1 = (5, 6) - m1 = self.random_complex_matrix(s1) - m2 = self.random_complex_matrix(s1) - m3 = self.random_complex_matrix(s1) - m4 = self.random_complex_matrix(s1) - inputs = [(m1, m2), (m3, m4)] - self.check_op_with_torch( - 
partial(jt.gradfunctional.jvp, func=adder_jt, create_graph=True), - partial(torch.autograd.functional.jvp, func=adder_pt, create_graph=True), - inputs, - jittor_knames = ['inputs', 'v'], - torch_knames = ['inputs', 'v'], - check_grad=False, - ) - - def test_real_vjp_exp(self): - def exp_reducer(x): - return x.exp().sum(dim=1) - s1 = (5, 6) - s2 = (5,) - m1 = self.random_real_matrix(s1) - m2 = self.random_real_matrix(s2) - inputs = [m1, m2] - self.check_op_with_torch( - partial(jt.gradfunctional.vjp, func=exp_reducer), - partial(torch.autograd.functional.vjp, func=exp_reducer), - inputs, - jittor_knames = ['inputs', 'v'], - torch_knames = ['inputs', 'v'], - check_grad=False) - - def test_complex_vjp_exp(self): - def exp_reducer(x): - return x.exp().sum(1) - s1 = (5, 6) - s2 = (5,) - m1 = self.random_real_matrix(s1) - m2 = self.random_real_matrix(s2) - inputs = [m1, m2] - self.check_op_with_torch( - partial(jt.gradfunctional.vjp, func=exp_reducer, create_graph=True), - partial(torch.autograd.functional.vjp, func=exp_reducer, create_graph=True), - inputs, - jittor_knames = ['inputs', 'v'], - torch_knames = ['inputs', 'v'], - check_grad=False, - ) - - def test_real_vjp_add(self): - w1, w2 = np.random.rand(), np.random.rand() - def adder(x, y): - return w1 * x + w2 * y - s1 = (5, 6) - m1 = self.random_real_matrix(s1) - m2 = self.random_real_matrix(s1) - m3 = self.random_real_matrix(s1) - inputs = [(m1, m2), m3] - self.check_op_with_torch( - partial(jt.gradfunctional.vjp, func=adder, create_graph=True), - partial(torch.autograd.functional.vjp, func=adder, create_graph=True), - inputs, - jittor_knames = ['inputs', 'v'], - torch_knames = ['inputs', 'v'], - check_grad=False, - ) - - def test_complex_vjp_add(self): - w1r, w1i = np.random.rand(), np.random.rand() - w2r, w2i = np.random.rand(), np.random.rand() - def adder_pt(x, y): - return (w1r + 1j * w1i) * x + (w2r + 1j * w2i) * y - def adder_jt(x, y): - w1 = ComplexNumber(real=jt.array(w1r).reshape(1,1), imag = jt.array(w1i).reshape(1,1)) - w2 = ComplexNumber(real=jt.array(w2r).reshape(1,1), imag = jt.array(w2i).reshape(1,1)) - return w1 * x + w2 * y - s1 = (5, 6) - m1 = self.random_complex_matrix(s1) - m2 = self.random_complex_matrix(s1) - m3 = self.random_complex_matrix(s1) - inputs = [(m1, m2), (m3)] - self.check_op_with_torch( - partial(jt.gradfunctional.vjp, func=adder_jt, create_graph=True), - partial(torch.autograd.functional.vjp, func=adder_pt, create_graph=True), - inputs, - jittor_knames = ['inputs', 'v'], - torch_knames = ['inputs', 'v'], - check_grad=False, - ) if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file