diff --git a/python/jittor/extern/acl/acl_compiler.py b/python/jittor/extern/acl/acl_compiler.py
index 128fd612..caf154b9 100644
--- a/python/jittor/extern/acl/acl_compiler.py
+++ b/python/jittor/extern/acl/acl_compiler.py
@@ -158,33 +158,8 @@ def acl_cmd(name: str,
             attr_code: str = "",
             attr_header: str = "",
             outputs: list = None):
-    # inputs: list,
-    # output_dtypes: list,
-    # output_shapes: list,
-    # attr_code: str = ""):
-    # input_code = ''
-    # for i in range(len(inputs)):
-    #     input_code += f"op.add(in{i}, true);\n"
-
-    # output_code = ''
-    # for i in range(len(output_dtypes)):
-    #     output_code += f"op.add(out{i}, false);\n"
-
-    # # read the tmp_file.cpp to the cuda_header
-    # with open(
-    #         "/home/ma-user/work/zy/JittorHW/python/jittor/extern/acl/tmp_file.cpp",
-    #         "r") as f:
-    #     cuda_header = f.read()
-    # import jittor as jt
-    # return jt.code(output_shapes,
-    #                output_dtypes,
-    #                inputs,
-    #                cuda_header=cuda_header,
-    #                cuda_src=f"""
     attr_header = "\nnamespace jittor{" + attr_header + "}\n"
-    # read the tmp_file.cpp to the cuda_header
-    cuda_header = '#include "acl/aclops/aclops.h"'
     import jittor as jt
     outputs_ = []
@@ -194,11 +169,10 @@ def acl_cmd(name: str,
         assert output_dtypes is not None
         assert output_shapes is not None
         assert len(output_dtypes) == len(output_shapes)
-        # print(f'{name } output_dtypes', output_dtypes)
-        # print(f'{name } output_shapes', output_shapes)
+
         for i in range(len(output_shapes)):
             outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
-        # print(f'{name } outputs_', outputs_)
+
     input_code = ''
     for i in range(len(inputs)):
         input_code += f"op.add(in{i}, true);\n"
@@ -240,11 +214,9 @@ def acl_cmd_forward(name: str,
     assert output_dtypes is not None
     assert output_shapes is not None
     assert len(output_dtypes) == len(output_shapes)
-    # print(f'{name } output_dtypes', output_dtypes)
-    # print(f'{name } output_shapes', output_shapes)
     for i in range(len(output_shapes)):
         outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
-    # print(f'{name } outputs_', outputs_)
+
     input_code = ''
     for i in range(len(inputs)):
         input_code += f"op.add(in{i}, true);\n"
@@ -269,6 +241,31 @@ def acl_cmd_forward(name: str,
 
 
 def change_function():
     import jittor as jt
     from jittor import Function
+    from .aclops.flashattention_op import FlashAttentionACL
+    from .aclops.conv_op import ConvACL
+    from .aclops.pool_op import PoolACL
+    from .aclops.nantonum_op import NanToNumACL
+    from .aclops.stack_op import StackACL
+    from .aclops.rope_op import RopeACL
+    from .aclops.softmax_op import SoftmaxACL
+    from .aclops.sigmoid_op import SigmoidACL
+    from .aclops.silu_op import SiLUACL
+    from .aclops.dropout_op import DropoutACL
+    from .aclops.relu_op import LeakyReLUACL
+    from .aclops.flip_op import FlipACL
+    from .aclops.concat_op import ConcatACL
+    from .aclops.gather_scatter_op import GatherACL
+    from .aclops.cumsum_op import CumsumACL
+    from .aclops.index_op import IndexACL
+    from .aclops.gather_scatter_op import ScatterACL
+    from .aclops.where_op import WhereACL
+    from .aclops.where_op import NonzeroACL
+    from .aclops.floor_op import FloorIntACL
+    from .aclops.getitem_op import GetItemACL
+    from .aclops.setitem_op import SetItemACL
+    from .aclops.bmm_op import BmmACL
+    from .aclops.matmul_op import MatmulACL
+    from .aclops.transpose_op import TransPoseACL
     from .aclops.triu_op import TriuACL
@@ -276,6 +273,7 @@ def change_function():
         return TriuACL()(x, diagonal)
 
     from .aclops.conv_op import ConvACL
+
     def conv_acl(x,
                  weight,
                  bias=None,
@@ -385,6 +383,7 @@ def change_function():
                            self.padding, self.dilation, self.groups)
             return ret
+
     from .aclops.flip_op import FlipACL
 
     def flip_acl(x, dim):
         return FlipACL()(x, dim)
@@ -441,6 +440,7 @@ def change_function():
         return FloorIntACL()(x)
 
     from .aclops.getitem_op import GetItemACL
+
     def getitem_acl(x, slices, return_x=None):
         # Transform numpy int to int
        if isinstance(slices, (np.int8, np.int16, np.int32, np.int64)):
@@ -489,19 +489,25 @@ def change_function():
 
         return result
+
     from .aclops.setitem_op import SetItemACL
+
     def setitem_acl(x, slices, value):
         res = SetItemACL()(x, slices, value)
         return x.assign(res)
+
     from .aclops.bmm_op import BmmACL
+
     def bmm_acl(x1, x2):
         return BmmACL()(x1, x2)
 
     def bmm_transpose_acl(x1, x2):
         return BmmACL(True)(x1, x2)
+
     from .aclops.matmul_op import MatmulACL
+
     def matmul_acl(x1, x2):
         return MatmulACL()(x1, x2)
@@ -512,6 +518,7 @@ def change_function():
 
     def transpose_acl(x, *dim):
         return TransPoseACL()(x, *dim)
+
     class ReLUACL(Function):
 
         def __init__(self):
@@ -550,6 +557,7 @@ def change_function():
         return ReLUACL()(x)
 
     from .aclops.relu_op import LeakyReLUACL
+
     class LeakyReLU(jt.nn.Module):
 
         def __init__(self, negative_slope=0.01):
@@ -563,6 +571,7 @@ def change_function():
         return LeakyReLUACL()(x, scale)
 
     from .aclops.dropout_op import DropoutACL
+
     class Dropout(jt.nn.Module):
 
         def __init__(self, p=0.5, is_train=False):
@@ -620,8 +629,7 @@ def change_function():
    #     def execute(self, x):
    #         res = embedding_acl(x, self.weight)
    #         return res
-
-    from .aclops.softmax_op import SoftmaxACL
+
     class Softmax(jt.nn.Module):
 
         def __init__(self):
@@ -642,7 +650,7 @@ def change_function():
         return StackACL()(x, dim)
 
     from .aclops.nantonum_op import NanToNumACL
-
+
     def isnan_acl(x):
         tonum = NanToNumACL()(x, -1.0)
         return jt.not_equal(x, tonum).logical_and(
@@ -694,8 +702,7 @@ def change_function():
     jt.nn.conv2d = warp(jt.nn.conv2d, conv_acl)
     jt.nn.Conv2d = warp(jt.nn.Conv2d, Conv2D)
     jt.nn.Conv = warp(jt.nn.Conv, Conv2D)
-
-    from .aclops.pool_op import PoolACL
+
     jt.nn.Pool = warp(jt.nn.Pool, PoolACL)
 
     jt.flip = warp(jt.flip, flip_acl)
@@ -743,6 +750,7 @@ def change_function():
     jt.Var.__setitem__ = lambda x, slices, value: warp(
         fake_setitem, setitem_acl, name='setitem')(x, slices, value)
+
     fake_matmul = jt.Var.matmul
     jt.nn.bmm = warp(jt.nn.bmm, bmm_acl)
     jt.bmm = warp(jt.bmm, bmm_acl)
     jt.nn.matmul = warp(jt.matmul, matmul_acl)
@@ -750,6 +758,7 @@ def change_function():
     jt.nn.matmul_transpose = warp(jt.nn.matmul_transpose, matmul_transpose_acl)
     jt.nn.bmm_transpose = warp(jt.nn.bmm_transpose, bmm_transpose_acl)
     jt.bmm_transpose = warp(jt.bmm_transpose, bmm_transpose_acl)
+
     jt.Var.__matmul__ = lambda x, y: warp(fake_matmul, matmul_acl)(x, y)
     jt.transpose = warp(jt.transpose, transpose_acl)
     fake_transpose = jt.transpose
@@ -784,7 +793,7 @@ def change_function():
     # from .aclops.norms_op import BatchNormACL,LayerNormACL
     # jt.nn.BatchNorm = warp(jt.nn.BatchNorm, BatchNormACL)
     # jt.nn.LayerNorm = warp(jt.nn.LayerNorm, LayerNormACL)
-    from .aclops.flashattention_op import FlashAttentionACL
+
     jt.nn.FlashAttention = warp(jt.nn.FlashAttention, FlashAttentionACL)
     jt.isnan = warp(jt.isnan, isnan_acl)
     jt.isinf = warp(jt.isinf, isinf_acl)
diff --git a/python/jittor/extern/acl/acl_op_exec.cc b/python/jittor/extern/acl/acl_op_exec.cc
index 9e295775..727f8349 100644
--- a/python/jittor/extern/acl/acl_op_exec.cc
+++ b/python/jittor/extern/acl/acl_op_exec.cc
@@ -154,9 +154,9 @@ namespace jittor
         std::queue queue;
 
         for (Op *op : fop->ops)
-            op_indeg[op] = 0;
+            op_indeg[op] = 0;
 
-        map> out_map;
+        map> out_map;
         map> from;
         int len = 0;
@@
 -303,15 +303,12 @@ namespace jittor
                 op.op_attr.reset(attr);
                 op.add(rop->y, false);
                 op.run();
+                aclrtSynchronizeStream(aclstream);
             }
             else if (op->name() == string("broadcast_to"))
             {
                 auto bop = (BroadcastToOp *)op;
                 AclOpRunner op("Expand");
-                if (bop->x->shape.size() == 1 && bop->x->shape[0] == 1)
-                {
-                    aclrtSynchronizeStream(aclstream);
-                }
                 op.jt_name = "expand";
                 NanoVector xshape, xshape_bk = bop->x->shape;
                 NanoVector zshape = bop->z->shape;
@@ -333,13 +330,13 @@ namespace jittor
                 op.add(bop->z, false);
                 op.run();
                 bop->x->shape = xshape_bk;
-                // aclrtSynchronizeStream(aclstream);
+                aclrtSynchronizeStream(aclstream);
             }
             else if (op->name() == string("fuse_transpose"))
             {
                 // replace fuse_transpose with transpose
                 auto top = (TransposeOp *)op;
-                AclOpRunner op("Transpose");
+                TransposeOpRunner op;
                 op.add(top->x, true);
                 op.add(top->y, false);
                 op.jt_name = "transpose";
@@ -497,4 +494,4 @@ namespace jittor
         }
     }
 
-} // jittor
\ No newline at end of file
+} // jittor
diff --git a/python/jittor/extern/acl/aclops/acl_op.h b/python/jittor/extern/acl/aclops/acl_op.h
index 473bebd6..1e8c2352 100644
--- a/python/jittor/extern/acl/aclops/acl_op.h
+++ b/python/jittor/extern/acl/aclops/acl_op.h
@@ -274,7 +274,7 @@ namespace jittor
            // ret = aclrtSynchronizeStream(aclstream);
            // CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclrtSynchronizeStream failed. ERROR: %d\n", name.c_str(), ret); return);
            // }
-
+
            // 6. 释放aclTensor和aclScalar,需要根据具体API的接口定义修改
            // destroy tensor
            // for (int idx = 0; idx < input_num; idx++)
diff --git a/python/jittor/extern/acl/aclops/base_op.h b/python/jittor/extern/acl/aclops/base_op.h
index fcf1e752..93eef363 100644
--- a/python/jittor/extern/acl/aclops/base_op.h
+++ b/python/jittor/extern/acl/aclops/base_op.h
@@ -33,19 +33,20 @@ namespace jittor
         // Common functionality for adding input/output variables
         void add(Var *v, bool is_input);
-
+
         virtual void setupInputDesc();
-
+
         void cleanupDesc();
 
         virtual void setupOutputDesc();
-
+
         virtual void syncRun();
 
         void checkRet(aclnnStatus ret);
-
+
         // Base run method with common operator lookup logic
         void run();
+
     protected:
         // Virtual method for specific operator execution
         virtual void executeOp(std::unordered_map::iterator &it) = 0;
diff --git a/python/jittor/extern/acl/aclops/base_op_acl.cc b/python/jittor/extern/acl/aclops/base_op_acl.cc
index 3b2101a7..900300c7 100644
--- a/python/jittor/extern/acl/aclops/base_op_acl.cc
+++ b/python/jittor/extern/acl/aclops/base_op_acl.cc
@@ -105,13 +105,13 @@ namespace jittor
 
     void BaseOpRunner::syncRun()
     {
-        if(sync_run) {
-            ret = aclrtSynchronizeStream(aclstream);
-            CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclrtSynchronizeStream failed. ERROR: %d\n", name.c_str(), ret); return);
+        if (sync_run)
+        {
+            // ret = aclrtSynchronizeStream(aclstream);
+            // CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclrtSynchronizeStream failed. ERROR: %d\n", name.c_str(), ret); return);
         }
     }
 
-
     void BaseOpRunner::checkRet(aclnnStatus ret)
     {
         if (ret != ACL_SUCCESS)
diff --git a/python/jittor/extern/acl/aclops/binary_op_acl.cc b/python/jittor/extern/acl/aclops/binary_op_acl.cc
index 14f6382e..18142491 100644
--- a/python/jittor/extern/acl/aclops/binary_op_acl.cc
+++ b/python/jittor/extern/acl/aclops/binary_op_acl.cc
@@ -116,7 +116,7 @@ namespace jittor
         ret = it->second.executeFunc(workspaceAddr, workspaceSize, executor, aclstream);
         CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed.
ERROR: %d\n", name.c_str(), ret); return); - //syncRun(); + syncRun(); aclDestroyScalar(alpha); return; diff --git a/python/jittor/extern/acl/aclops/bmm_op.py b/python/jittor/extern/acl/aclops/bmm_op.py index 7b1c3891..78fee49c 100644 --- a/python/jittor/extern/acl/aclops/bmm_op.py +++ b/python/jittor/extern/acl/aclops/bmm_op.py @@ -12,15 +12,14 @@ from typing import Union from collections.abc import Sequence, Iterable - def acl_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None, - extra_data: dict = {}): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None, + extra_data: dict = {}): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -38,7 +37,7 @@ def acl_cmd(name: str, input_code = '' for i in range(len(inputs)): input_code += f"op.add(in{i}, true);\n" - + return jt.code(outputs=outputs_, inputs=inputs, cuda_header=attr_header + cuda_header, @@ -49,8 +48,9 @@ def acl_cmd(name: str, op.add(out0, false); {attr_code} op.run();""", - data=extra_data) - + data=extra_data) + + class BmmACL(jt.Function): def __init__(self, trans_x2=False): @@ -59,16 +59,14 @@ class BmmACL(jt.Function): def execute(self, x1, x2): self.input = [x1, x2] - result = acl_cmd( - "BatchMatMul", [x1, x2], - output_dtypes=[x1.dtype], - output_shapes=[ - x1.shape[:-1] + - x2.shape[-2:-1] if self.trans_x2 else x1.shape[:-1] + - x2.shape[-1:] - ], - attr_code="op.jt_name=\"bmm_trans_1\";" - if self.trans_x2 else "op.jt_name=\"bmm\";")[0] + result = acl_cmd("BatchMatMul", [x1, x2], + output_dtypes=[x1.dtype], + output_shapes=[ + x1.shape[:-1] + x2.shape[-2:-1] if self.trans_x2 + else x1.shape[:-1] + x2.shape[-1:] + ], + attr_code="op.jt_name=\"bmm_trans_1\";" + if self.trans_x2 else "op.jt_name=\"bmm\";")[0] return result @@ -78,57 +76,53 @@ class BmmACL(jt.Function): reshape_grad_x2 = True else: reshape_grad_x2 = False - grad_x1 = acl_cmd("BatchMatMul", [grad_output, x2], - output_dtypes=[x1.dtype], - output_shapes=[ - grad_output.shape[:-1] + - x2.shape[-2:-1] if not self.trans_x2 else - grad_output.shape[:-1] + x1.shape[-1:] - ], - attr_code="op.jt_name=\"bmm_trans_1\";" if - not self.trans_x2 else "op.jt_name=\"bmm\";")[0] + grad_x1 = acl_cmd( + "BatchMatMul", [grad_output, x2], + output_dtypes=[x1.dtype], + output_shapes=[ + grad_output.shape[:-1] + x2.shape[-2:-1] if not self.trans_x2 + else grad_output.shape[:-1] + x1.shape[-1:] + ], + attr_code="op.jt_name=\"bmm_trans_1\";" + if not self.trans_x2 else "op.jt_name=\"bmm\";")[0] if self.trans_x2: if reshape_grad_x2: output_shape = grad_output.shape[1:-2] + grad_output.shape[ -1:] + x1.shape[-1:] - grad_x2 = acl_cmd( - "BatchMatMul", [ - grad_output.reshape(-1, grad_output.shape[-1]), - x1.reshape(-1, x1.shape[-1]) - ], - output_dtypes=[x2.dtype], - output_shapes=[output_shape], - attr_code="op.jt_name=\"bmm_trans_0\";")[0] + grad_x2 = acl_cmd("BatchMatMul", [ + grad_output.reshape(-1, grad_output.shape[-1]), + x1.reshape(-1, x1.shape[-1]) + ], + output_dtypes=[x2.dtype], + output_shapes=[output_shape], + attr_code="op.jt_name=\"bmm_trans_0\";")[0] else: output_shape = grad_output.shape[:-2] + grad_output.shape[ -1:] + x1.shape[-1:] - grad_x2 = acl_cmd( - "BatchMatMul", [grad_output, x1], - output_dtypes=[x2.dtype], - output_shapes=[output_shape], - attr_code="op.jt_name=\"bmm_trans_0\";")[0] + grad_x2 = acl_cmd("BatchMatMul", [grad_output, 
x1], + output_dtypes=[x2.dtype], + output_shapes=[output_shape], + attr_code="op.jt_name=\"bmm_trans_0\";")[0] else: if reshape_grad_x2: output_shape = x1.shape[1:-2] + x1.shape[ -1:] + grad_output.shape[-1:] - grad_x2 = acl_cmd( - "BatchMatMul", [ - x1.reshape(-1, x1.shape[-1]), - grad_output.reshape(-1, grad_output.shape[-1]) - ], - output_dtypes=[x2.dtype], - output_shapes=[output_shape], - attr_code="op.jt_name=\"bmm_trans_0\";")[0] + grad_x2 = acl_cmd("BatchMatMul", [ + x1.reshape(-1, x1.shape[-1]), + grad_output.reshape(-1, grad_output.shape[-1]) + ], + output_dtypes=[x2.dtype], + output_shapes=[output_shape], + attr_code="op.jt_name=\"bmm_trans_0\";")[0] else: output_shape = x1.shape[:-2] + x1.shape[ -1:] + grad_output.shape[-1:] - grad_x2 = acl_cmd( - "BatchMatMul", [x1, grad_output], - output_dtypes=[x2.dtype], - output_shapes=[output_shape], - attr_code="op.jt_name=\"bmm_trans_0\";")[0] + grad_x2 = acl_cmd("BatchMatMul", [x1, grad_output], + output_dtypes=[x2.dtype], + output_shapes=[output_shape], + attr_code="op.jt_name=\"bmm_trans_0\";")[0] if len(grad_x1.shape) > len(x1.shape): grad_x1 = grad_x1.sum(0) if len(grad_x2.shape) > len(x2.shape): grad_x2 = grad_x2.sum(0) - return grad_x1, grad_x2 \ No newline at end of file + return grad_x1, grad_x2 diff --git a/python/jittor/extern/acl/aclops/bmm_op_acl.cc b/python/jittor/extern/acl/aclops/bmm_op_acl.cc index 82825212..b75f9177 100644 --- a/python/jittor/extern/acl/aclops/bmm_op_acl.cc +++ b/python/jittor/extern/acl/aclops/bmm_op_acl.cc @@ -63,8 +63,8 @@ namespace jittor } void BatchMatMulOpRunner::executeOp(std::unordered_map::iterator &it) { - - ret = aclnnBatchMatMulGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor); + + ret = aclnnBatchMatMulGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnBatchMatmulGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return); if (workspaceSize > 0) { @@ -72,6 +72,6 @@ namespace jittor } ret = aclnnBatchMatMul(workspaceAddr, workspaceSize, executor, aclstream); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnbatchMatmul failed. 
ERROR: %d\n", name.c_str(), ret); return); - // syncRun(); + syncRun(); } } \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/bmm_op_acl.h b/python/jittor/extern/acl/aclops/bmm_op_acl.h index 1cf11457..283bf5fc 100644 --- a/python/jittor/extern/acl/aclops/bmm_op_acl.h +++ b/python/jittor/extern/acl/aclops/bmm_op_acl.h @@ -10,6 +10,7 @@ namespace jittor protected: void setupInputDesc() override; void executeOp(std::unordered_map::iterator &it) override; + public: BatchMatMulOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/concat_op.py b/python/jittor/extern/acl/aclops/concat_op.py index faca9d77..bd02f9e9 100644 --- a/python/jittor/extern/acl/aclops/concat_op.py +++ b/python/jittor/extern/acl/aclops/concat_op.py @@ -11,13 +11,14 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def concat_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -51,6 +52,7 @@ def concat_cmd(name: str, {attr_code} op.run();""") + class ConcatACL(jt.Function): def __init__(self): @@ -114,13 +116,11 @@ class ConcatACL(jt.Function): self.dim = dim for i in range(len(input_tensors)): if input_tensors[i].dtype != input_tensors[0].dtype: - raise ValueError( - "All input tensors must have the same dtype") + raise ValueError("All input tensors must have the same dtype") if input_tensors[i].shape[:dim] != input_tensors[ 0].shape[:dim] or input_tensors[i].shape[ dim + 1:] != input_tensors[0].shape[dim + 1:]: - raise ValueError( - "All input tensors must have the same shape") + raise ValueError("All input tensors must have the same shape") attr_code = f""" op.jt_name = "concat"; ConcatAttr *attr = new ConcatAttr(); @@ -133,15 +133,13 @@ class ConcatACL(jt.Function): input_tensors, output_dtypes=[input_tensors[0].dtype], output_shapes=[ - jt.empty(self.calculate_output_shape(input_tensors, - dim)).shape + jt.empty(self.calculate_output_shape(input_tensors, dim)).shape ], attr_code=attr_code)[0] return result def _grad(self, *args): - new_args = ((args[i] if i >= 0 else None) - for i in self.output_mask) + new_args = ((args[i] if i >= 0 else None) for i in self.output_mask) ret = self.grad(*new_args) new_ret = [] for i, r in enumerate(ret): @@ -185,4 +183,4 @@ class ConcatACL(jt.Function): output_dtypes=dtypeVec, output_shapes=shapeVec, attr_code=attr_code) - return result \ No newline at end of file + return result diff --git a/python/jittor/extern/acl/aclops/concat_op_acl.h b/python/jittor/extern/acl/aclops/concat_op_acl.h index 3e2df4eb..a051e343 100644 --- a/python/jittor/extern/acl/aclops/concat_op_acl.h +++ b/python/jittor/extern/acl/aclops/concat_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: ConcatOpRunner(); }; @@ -18,6 +19,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: SplitWithSizeOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/conv_op.py b/python/jittor/extern/acl/aclops/conv_op.py index 5bcd71b8..e91c3c98 100644 --- a/python/jittor/extern/acl/aclops/conv_op.py +++ b/python/jittor/extern/acl/aclops/conv_op.py @@ -22,16 +22,18 @@ def _ntuple(n): return parse + 
_pair = _ntuple(2) + def conv_forward(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None, - extra_data: dict = {}): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None, + extra_data: dict = {}): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -49,7 +51,7 @@ def conv_forward(name: str, input_code = '' for i in range(len(inputs)): input_code += f"op.add(in{i}, true);\n" - + return jt.code(outputs=outputs_, inputs=inputs, cuda_header=attr_header + cuda_header, @@ -60,16 +62,17 @@ def conv_forward(name: str, op.add(out0, false); {attr_code} op.run();""", - data=extra_data) + data=extra_data) + def conv_forward(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None, - extra_data: dict = {}): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None, + extra_data: dict = {}): # TODO: not done for now attr_header = "\nnamespace jittor{" + attr_header + "}\n" @@ -88,7 +91,7 @@ def conv_forward(name: str, input_code = '' for i in range(len(inputs)): input_code += f"op.add(in{i}, true);\n" - + return jt.code(outputs=outputs_, inputs=inputs, cuda_header=attr_header + cuda_header, @@ -99,32 +102,33 @@ def conv_forward(name: str, op.add(out0, false); {attr_code} op.run();""", - data=extra_data) + data=extra_data) + class ConvACL(jt.Function): - def execute(self, - x, - weight, - bias=None, - stride=1, - padding=0, - dilation=1, - groups=1): - self.input = x - self.weight = weight - self.bias = bias - padding = _pair(padding) - stride = _pair(stride) - dilation = _pair(dilation) - out_channels = weight.shape[0] - if groups <= 0: - raise ValueError("groups must be a positive integer") - self.padding = padding - self.stride = stride - self.dilation = dilation - self.groups = groups - attr_code = f""" + def execute(self, + x, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1): + self.input = x + self.weight = weight + self.bias = bias + padding = _pair(padding) + stride = _pair(stride) + dilation = _pair(dilation) + out_channels = weight.shape[0] + if groups <= 0: + raise ValueError("groups must be a positive integer") + self.padding = padding + self.stride = stride + self.dilation = dilation + self.groups = groups + attr_code = f""" op.jt_name = "conv2d"; ConvAttr *attr = new ConvAttr(); attr->convStrides = {{ {stride[0]}, {stride[1]} }}; @@ -134,49 +138,48 @@ class ConvACL(jt.Function): attr->convOutPads = {{1,1}}; op.op_attr.reset(attr); """ - input_height, input_width = x.shape[-2:] - kernel_height, kernel_width = weight.shape[-2:] + input_height, input_width = x.shape[-2:] + kernel_height, kernel_width = weight.shape[-2:] - output_height = (input_height + 2 * padding[0] - dilation[0] * - (kernel_height - 1) - 1) // stride[0] + 1 - output_width = (input_width + 2 * padding[1] - dilation[1] * - (kernel_width - 1) - 1) // stride[1] + 1 + output_height = (input_height + 2 * padding[0] - dilation[0] * + (kernel_height - 1) - 1) // stride[0] + 1 + output_width = (input_width + 2 * padding[1] - dilation[1] * + (kernel_width - 1) - 1) // stride[1] + 1 - output_shape = (x.shape[0], out_channels, output_height, - output_width) + output_shape = (x.shape[0], out_channels, 
output_height, output_width) - inputs = [x, weight] - if bias is not None: - inputs.append(bias) - result = conv_forward( - "Conv2d", - inputs, - output_dtypes=[x.dtype], - output_shapes=[output_shape], - attr_code=attr_code, - )[0] - return result + inputs = [x, weight] + if bias is not None: + inputs.append(bias) + result = conv_forward( + "Conv2d", + inputs, + output_dtypes=[x.dtype], + output_shapes=[output_shape], + attr_code=attr_code, + )[0] + return result - def grad(self, grad_output): - x = self.input - weight = self.weight - bias = self.bias - inputs = [grad_output, x, weight] - if bias is not None: - inputs.append(bias) - output_shapes = [x.shape, weight.shape] - output_dtypes = [x.dtype, weight.dtype] - if bias is not None: - output_shapes.append(bias.shape) - output_dtypes.append(bias.dtype) - else: - output_shapes.append([1]) - output_dtypes.append(x.dtype) - padding = self.padding - stride = self.stride - dilation = self.dilation - groups = self.groups - attr_code = f""" + def grad(self, grad_output): + x = self.input + weight = self.weight + bias = self.bias + inputs = [grad_output, x, weight] + if bias is not None: + inputs.append(bias) + output_shapes = [x.shape, weight.shape] + output_dtypes = [x.dtype, weight.dtype] + if bias is not None: + output_shapes.append(bias.shape) + output_dtypes.append(bias.dtype) + else: + output_shapes.append([1]) + output_dtypes.append(x.dtype) + padding = self.padding + stride = self.stride + dilation = self.dilation + groups = self.groups + attr_code = f""" op.jt_name = "conv2dbackward"; ConvAttr *attr = new ConvAttr(); attr->convStrides = {{ {stride[0]}, {stride[1]} }}; @@ -186,12 +189,12 @@ class ConvACL(jt.Function): attr->convOutPads = {{ 1,1}}; op.op_attr.reset(attr); """ - results = acl_cmd_forward("Conv2dBackward", - inputs, - output_dtypes=output_dtypes, - output_shapes=output_shapes, - attr_code=attr_code) - if self.bias is None: - return results[0], results[1] + results = acl_cmd_forward("Conv2dBackward", + inputs, + output_dtypes=output_dtypes, + output_shapes=output_shapes, + attr_code=attr_code) + if self.bias is None: + return results[0], results[1] - return results \ No newline at end of file + return results diff --git a/python/jittor/extern/acl/aclops/conv_op_acl.cc b/python/jittor/extern/acl/aclops/conv_op_acl.cc index b7865af3..6e8b6ebf 100644 --- a/python/jittor/extern/acl/aclops/conv_op_acl.cc +++ b/python/jittor/extern/acl/aclops/conv_op_acl.cc @@ -34,7 +34,7 @@ namespace jittor { use_nchw = true; } - + void ConvOpRunner::executeOp(std::unordered_map::iterator &it) { // for conv @@ -66,7 +66,7 @@ namespace jittor ret = aclnnConvolution(workspaceAddr, workspaceSize, executor, aclstream); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed. 
ERROR: %d\n", name.c_str(), ret); return); - // syncRun(); + syncRun(); aclDestroyIntArray(strides); aclDestroyIntArray(pads); diff --git a/python/jittor/extern/acl/aclops/conv_op_acl.h b/python/jittor/extern/acl/aclops/conv_op_acl.h index 3c03e8e5..3361d4af 100644 --- a/python/jittor/extern/acl/aclops/conv_op_acl.h +++ b/python/jittor/extern/acl/aclops/conv_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: ConvOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/cumsum_op.py b/python/jittor/extern/acl/aclops/cumsum_op.py index cf1ba4fa..28ce48b8 100644 --- a/python/jittor/extern/acl/aclops/cumsum_op.py +++ b/python/jittor/extern/acl/aclops/cumsum_op.py @@ -11,13 +11,14 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def cumsum_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -51,6 +52,7 @@ def cumsum_cmd(name: str, {attr_code} op.run();""") + class CumsumACL(jt.Function): def __init__(self): @@ -85,15 +87,15 @@ class CumsumACL(jt.Function): op.op_attr.reset(attr); """ flipped_grad_output = cumsum_cmd("Flip", [grad_output], - output_dtypes=[grad_output.dtype], - output_shapes=[grad_output.shape], - attr_code=flip_attr_code)[0] + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr_code=flip_attr_code)[0] cumulative_grad = cumsum_cmd("Cumsum", [flipped_grad_output], - output_dtypes=[grad_output.dtype], - output_shapes=[grad_output.shape], - attr_code=cumsum_attr_code)[0] + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr_code=cumsum_attr_code)[0] grad_input = cumsum_cmd("Flip", [cumulative_grad], output_dtypes=[grad_output.dtype], output_shapes=[grad_output.shape], attr_code=flip_attr_code)[0] - return grad_input \ No newline at end of file + return grad_input diff --git a/python/jittor/extern/acl/aclops/cumsum_op_acl.h b/python/jittor/extern/acl/aclops/cumsum_op_acl.h index a7ad343f..1b9888f1 100644 --- a/python/jittor/extern/acl/aclops/cumsum_op_acl.h +++ b/python/jittor/extern/acl/aclops/cumsum_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: CumsumOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/dropout_op.py b/python/jittor/extern/acl/aclops/dropout_op.py index 862a26b5..c7f3327a 100644 --- a/python/jittor/extern/acl/aclops/dropout_op.py +++ b/python/jittor/extern/acl/aclops/dropout_op.py @@ -11,13 +11,14 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def dropout_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -51,6 +52,7 @@ def dropout_cmd(name: str, {attr_code} op.run();""") + class DropoutACL(jt.Function): def __init__(self): @@ -71,9 +73,9 @@ class DropoutACL(jt.Function): op.op_attr.reset(attr); 
""" result = dropout_cmd("Dropout", [x], - output_dtypes=[x.dtype, "uint8"], - output_shapes=[x.shape, mask_shape], - attr_code=attr_code) + output_dtypes=[x.dtype, "uint8"], + output_shapes=[x.shape, mask_shape], + attr_code=attr_code) self.maskout = result[1] return result[0] @@ -85,8 +87,8 @@ class DropoutACL(jt.Function): op.op_attr.reset(attr); """ grad_input = dropout_cmd("DropoutBackward", - [grad_output, self.maskout], - output_dtypes=[grad_output.dtype], - output_shapes=[grad_output.shape], - attr_code=attr_code)[0] - return grad_input \ No newline at end of file + [grad_output, self.maskout], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr_code=attr_code)[0] + return grad_input diff --git a/python/jittor/extern/acl/aclops/dropout_op_acl.h b/python/jittor/extern/acl/aclops/dropout_op_acl.h index 4e50b316..3380b0ec 100644 --- a/python/jittor/extern/acl/aclops/dropout_op_acl.h +++ b/python/jittor/extern/acl/aclops/dropout_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: DropoutOpRunner(); }; @@ -18,6 +19,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: DropoutBackwardOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/flashattention_op.py b/python/jittor/extern/acl/aclops/flashattention_op.py index 2061b548..ce220708 100644 --- a/python/jittor/extern/acl/aclops/flashattention_op.py +++ b/python/jittor/extern/acl/aclops/flashattention_op.py @@ -11,13 +11,14 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def flashattention_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -51,21 +52,22 @@ def flashattention_cmd(name: str, {attr_code} op.run();""") + class FlashAttentionACL(jt.Function): def __init__(self, - headnum, - layout="BNSD", - prefix=None, - qstart=None, - kvstart=None, - scale=1.0, - prob=1.0, - pretokens=2147483647, - nexttokens=2147483647, - innerprecise=0, - sparsemode=0, - psetype=1): + headnum, + layout="BNSD", + prefix=None, + qstart=None, + kvstart=None, + scale=1.0, + prob=1.0, + pretokens=2147483647, + nexttokens=2147483647, + innerprecise=0, + sparsemode=0, + psetype=1): self.headnum = headnum self.layout = layout self.scale = scale @@ -116,9 +118,7 @@ class FlashAttentionACL(jt.Function): self.prefix = self.prefix if self.prefix else [0 for _ in range(B)] self.qstart = self.qstart if self.qstart else [0 for _ in range(B)] - self.kvstart = self.kvstart if self.kvstart else [ - 0 for _ in range(B) - ] + self.kvstart = self.kvstart if self.kvstart else [0 for _ in range(B)] self.hasRealshift = (not realshift == None) self.hasDropmask = (not dropMask == None) @@ -126,8 +126,7 @@ class FlashAttentionACL(jt.Function): self.hasAttenmask = (not attenMask == None) # 待定,目前设为nullptr - self.realshift = realshift if realshift else jt.zeros( - B, N, SQ, SKV) + self.realshift = realshift if realshift else jt.zeros(B, N, SQ, SKV) self.dropMask = dropMask if dropMask else jt.ones(B, N, SQ, SKV) self.paddingMask = paddingMask if paddingMask else jt.zeros( B, N, SQ, SKV) @@ -207,4 +206,4 @@ class FlashAttentionACL(jt.Function): 
output_dtypes=[self.q.dtype, self.k.dtype, self.v.dtype], output_shapes=[self.q.shape, self.k.shape, self.v.shape], attr_code=attr_code) - return result \ No newline at end of file + return result diff --git a/python/jittor/extern/acl/aclops/flashattention_op_acl.cc b/python/jittor/extern/acl/aclops/flashattention_op_acl.cc index b5c554fa..43a71ab7 100644 --- a/python/jittor/extern/acl/aclops/flashattention_op_acl.cc +++ b/python/jittor/extern/acl/aclops/flashattention_op_acl.cc @@ -58,7 +58,6 @@ namespace jittor return; } - FlashAttentionBackwardOpRunner::FlashAttentionBackwardOpRunner() : BaseOpRunner("FlashAttentionBackward") { } diff --git a/python/jittor/extern/acl/aclops/flashattention_op_acl.h b/python/jittor/extern/acl/aclops/flashattention_op_acl.h index c81a35e3..16c02caa 100644 --- a/python/jittor/extern/acl/aclops/flashattention_op_acl.h +++ b/python/jittor/extern/acl/aclops/flashattention_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: FlashAttentionOpRunner(); }; @@ -18,6 +19,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: FlashAttentionBackwardOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/flip_op.py b/python/jittor/extern/acl/aclops/flip_op.py index 31757db1..f05c8be7 100644 --- a/python/jittor/extern/acl/aclops/flip_op.py +++ b/python/jittor/extern/acl/aclops/flip_op.py @@ -11,13 +11,14 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def flip_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -51,6 +52,7 @@ def flip_cmd(name: str, {attr_code} op.run();""") + class FlipACL(jt.Function): def __init__(self): @@ -70,14 +72,14 @@ class FlipACL(jt.Function): """ self.attr_code = attr_code result = flip_cmd("Flip", [input], - output_dtypes=[input.dtype], - output_shapes=[input.shape], - attr_code=self.attr_code)[0] + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr_code=self.attr_code)[0] return result def grad(self, grad_output): grad_input = flip_cmd("Flip", [grad_output], - output_dtypes=[grad_output.dtype], - output_shapes=[grad_output.shape], - attr_code=self.attr_code)[0] - return grad_input \ No newline at end of file + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr_code=self.attr_code)[0] + return grad_input diff --git a/python/jittor/extern/acl/aclops/flip_op_acl.h b/python/jittor/extern/acl/aclops/flip_op_acl.h index 05bf1e62..5b53700a 100644 --- a/python/jittor/extern/acl/aclops/flip_op_acl.h +++ b/python/jittor/extern/acl/aclops/flip_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: FlipOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/floor_op.py b/python/jittor/extern/acl/aclops/floor_op.py index 280f6f6d..35bed012 100644 --- a/python/jittor/extern/acl/aclops/floor_op.py +++ b/python/jittor/extern/acl/aclops/floor_op.py @@ -11,13 +11,14 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def floor_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: 
list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -51,6 +52,7 @@ def floor_cmd(name: str, {attr_code} op.run();""") + class FloorIntACL(jt.Function): def __init__(self): @@ -59,10 +61,10 @@ class FloorIntACL(jt.Function): def execute(self, input): self.shape = input.shape result = floor_cmd("Floor", [input], - output_dtypes=[input.dtype], - output_shapes=[input.shape], - attr_code="op.jt_name=\"floor\";")[0] + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr_code="op.jt_name=\"floor\";")[0] return result def grad(self, grad_output): - return jt.zeros(self.shape, dtype=grad_output.dtype) \ No newline at end of file + return jt.zeros(self.shape, dtype=grad_output.dtype) diff --git a/python/jittor/extern/acl/aclops/floor_op_acl.h b/python/jittor/extern/acl/aclops/floor_op_acl.h index ee235c3e..3e228b16 100644 --- a/python/jittor/extern/acl/aclops/floor_op_acl.h +++ b/python/jittor/extern/acl/aclops/floor_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: FloorOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/gather_scatter_op.py b/python/jittor/extern/acl/aclops/gather_scatter_op.py index f1f0c80d..748c5718 100644 --- a/python/jittor/extern/acl/aclops/gather_scatter_op.py +++ b/python/jittor/extern/acl/aclops/gather_scatter_op.py @@ -11,13 +11,14 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def gather_scatter_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -51,6 +52,7 @@ def gather_scatter_cmd(name: str, {attr_code} op.run();""") + class GatherACL(jt.Function): def __init__(self): @@ -66,9 +68,9 @@ class GatherACL(jt.Function): op.op_attr.reset(attr); """ result = gather_scatter_cmd("Gather", [input, index], - output_dtypes=[input.dtype], - output_shapes=[index.shape], - attr_code=attr_code)[0] + output_dtypes=[input.dtype], + output_shapes=[index.shape], + attr_code=attr_code)[0] return result def grad(self, grad_output): @@ -80,12 +82,14 @@ class GatherACL(jt.Function): attr->reduction = {1}; op.op_attr.reset(attr); """ - grad_input = gather_scatter_cmd("Scatter", [tmp, self.index, grad_output], - output_dtypes=[grad_output.dtype], - output_shapes=[tmp.shape], - attr_code=attr_code)[0] + grad_input = gather_scatter_cmd("Scatter", + [tmp, self.index, grad_output], + output_dtypes=[grad_output.dtype], + output_shapes=[tmp.shape], + attr_code=attr_code)[0] return grad_input + class ScatterACL(jt.Function): def __init__(self): @@ -103,9 +107,9 @@ class ScatterACL(jt.Function): op.op_attr.reset(attr); """ result = gather_scatter_cmd("Scatter", [input, self.index, src], - output_dtypes=[input.dtype], - output_shapes=[input.shape], - attr_code=attr_code)[0] + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr_code=attr_code)[0] return result def grad(self, grad_output): @@ -116,7 +120,7 @@ class ScatterACL(jt.Function): 
op.op_attr.reset(attr); """ grad_input = gather_scatter_cmd("Gather", [grad_output, self.index], - output_dtypes=[grad_output.dtype], - output_shapes=[self.index.shape], - attr_code=attr_code)[0] - return grad_output, None, None, grad_input \ No newline at end of file + output_dtypes=[grad_output.dtype], + output_shapes=[self.index.shape], + attr_code=attr_code)[0] + return grad_output, None, None, grad_input diff --git a/python/jittor/extern/acl/aclops/gather_scatter_op_acl.cc b/python/jittor/extern/acl/aclops/gather_scatter_op_acl.cc index cab4683b..871f5e83 100644 --- a/python/jittor/extern/acl/aclops/gather_scatter_op_acl.cc +++ b/python/jittor/extern/acl/aclops/gather_scatter_op_acl.cc @@ -77,5 +77,4 @@ namespace jittor return; } - } \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/gather_scatter_op_acl.h b/python/jittor/extern/acl/aclops/gather_scatter_op_acl.h index 9ccd876e..dd95814f 100644 --- a/python/jittor/extern/acl/aclops/gather_scatter_op_acl.h +++ b/python/jittor/extern/acl/aclops/gather_scatter_op_acl.h @@ -9,16 +9,17 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: GatherOpRunner(); }; - class ScatterOpRunner : public BaseOpRunner { protected: void executeOp(std::unordered_map::iterator &it) override; + public: ScatterOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/getitem_op.py b/python/jittor/extern/acl/aclops/getitem_op.py index b9d36351..91fc5d02 100644 --- a/python/jittor/extern/acl/aclops/getitem_op.py +++ b/python/jittor/extern/acl/aclops/getitem_op.py @@ -11,13 +11,14 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def getitem_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -51,6 +52,7 @@ def getitem_cmd(name: str, {attr_code} op.run();""") + def getitem_forward(name: str, inputs: list, output_dtypes: list = None, @@ -90,7 +92,8 @@ def getitem_forward(name: str, op.add(out0, false); {attr_code} op.run();""", - data=extra_data) + data=extra_data) + def caculate_shape(tensors): if isinstance(tensors, jt.Var): @@ -105,6 +108,7 @@ def caculate_shape(tensors): else: assert False, f"not implemented for {type(tensors)}" + def can_broadcast_and_shape(shape1, shape2): """ 检查两个张量是否可以广播,并返回广播后的形状。 @@ -144,6 +148,7 @@ def can_broadcast_and_shape(shape1, shape2): return True, tuple(broadcast_shape) + class GetItemACL(jt.Function): def __init__(self): @@ -174,9 +179,9 @@ class GetItemACL(jt.Function): op.jt_name = "maskedselect"; """ result = getitem_cmd("MaskedSelect", - inputs=inputs, - outputs=outputs, - attr_code=attr_code)[0] + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] result = result[:output_len] result.sync() return result @@ -194,7 +199,7 @@ class GetItemACL(jt.Function): contains_slice = False for s in slices: if not isinstance(s, jt.Var) and (isinstance(s, slice) - or s == Ellipsis): + or s == Ellipsis): contains_slice = True break if not contains_slice: @@ -212,9 +217,9 @@ class GetItemACL(jt.Function): output_shape = [1] for ii in slices: indices.append(jt.Var(ii).int32()) - if isinstance(slices[0], jt.Var) or isinstance( - slices[0], int) or isinstance( - slices[0], list) or 
isinstance(slices[0], tuple): + if isinstance(slices[0], + jt.Var) or isinstance(slices[0], int) or isinstance( + slices[0], list) or isinstance(slices[0], tuple): self.indices = indices inputs = [x] + indices attr_code = f""" @@ -222,10 +227,10 @@ class GetItemACL(jt.Function): """ self.type_ = 'index' result = getitem_cmd("Index", - inputs=inputs, - output_dtypes=[x.dtype], - output_shapes=[output_shape], - attr_code=attr_code)[0] + inputs=inputs, + output_dtypes=[x.dtype], + output_shapes=[output_shape], + attr_code=attr_code)[0] result.sync() return result assert contains_slice, "slice type error" @@ -235,8 +240,7 @@ class GetItemACL(jt.Function): if not isinstance(s, jt.Var) and s == Ellipsis: slices = slices[:slices.index(s)] + [ slice(None, None, None) - ] * (x_dim - len(slices) + 1) + slices[slices.index(s) + - 1:] + ] * (x_dim - len(slices) + 1) + slices[slices.index(s) + 1:] break slices = tuple(slices) @@ -313,11 +317,11 @@ class GetItemACL(jt.Function): op.op_attr.reset(attr); """ result = getitem_forward("SliceV2", - inputs, - output_dtypes=[x.dtype], - output_shapes=[jt.empty(sizes).shape], - attr_code=attr_code, - extra_data=extra_data)[0] + inputs, + output_dtypes=[x.dtype], + output_shapes=[jt.empty(sizes).shape], + attr_code=attr_code, + extra_data=extra_data)[0] self.squeeze_dims = squeeze_dims for dim in squeeze_dims[::-1]: result = jt.squeeze(result, dim) @@ -334,9 +338,9 @@ class GetItemACL(jt.Function): outputs = [jt.zeros(self.x_shape, dtype=grad_output.dtype)] # breakpoint() result = getitem_cmd("IndexPutImplAccumulate", - inputs=inputs, - outputs=outputs, - attr_code=attr_code)[0] + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] result.sync() return result, None elif self.type_ == 'slicev2': @@ -401,9 +405,9 @@ class GetItemACL(jt.Function): inputs = [grad_output] outputs = [jt.zeros(self.x_shape, dtype=grad_output.dtype)] result = getitem_cmd("StridedSliceAssignV2", - inputs=inputs, - outputs=outputs, - attr_code=attr_code)[0] + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] result.sync() if expand_dim: result = result.squeeze(-1) @@ -412,4 +416,4 @@ class GetItemACL(jt.Function): return self.mask.float() pass else: - assert False, f"grad not implemented for {self.type_}" \ No newline at end of file + assert False, f"grad not implemented for {self.type_}" diff --git a/python/jittor/extern/acl/aclops/getitem_op_acl.cc b/python/jittor/extern/acl/aclops/getitem_op_acl.cc index 27d16dd4..4c7c34d3 100644 --- a/python/jittor/extern/acl/aclops/getitem_op_acl.cc +++ b/python/jittor/extern/acl/aclops/getitem_op_acl.cc @@ -53,11 +53,10 @@ namespace jittor return; } - IndexOpRunner::IndexOpRunner() : BaseOpRunner("Index") { } - + void IndexOpRunner::executeOp(std::unordered_map::iterator &it) { auto input_num = in_.size(); @@ -81,7 +80,7 @@ namespace jittor SliceV2OpRunner::SliceV2OpRunner() : BaseOpRunner("SliceV2") { } - + void SliceV2OpRunner::executeOp(std::unordered_map::iterator &it) { auto attr = dynamic_cast(op_attr.get()); @@ -106,11 +105,10 @@ namespace jittor return; } - IndexPutImplAccumulateOpRunner::IndexPutImplAccumulateOpRunner() : BaseOpRunner("IndexPutImplAccumulate") { } - + void IndexPutImplAccumulateOpRunner::executeOp(std::unordered_map::iterator &it) { auto input_num = in_.size(); @@ -137,11 +135,9 @@ namespace jittor return; } - StridedSliceAssignV2OpRunner::StridedSliceAssignV2OpRunner() : BaseOpRunner("StridedSliceAssignV2") { } - void StridedSliceAssignV2OpRunner::executeOp(std::unordered_map::iterator &it) { @@ -162,9 
+158,7 @@ namespace jittor ret = aclnnStridedSliceAssignV2(workspaceAddr, workspaceSize, executor, aclstream); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnStridedSliceAssignV2 failed. ERROR: %d\n", name.c_str(), ret); return); - // syncRun(); - - + syncRun(); return; } diff --git a/python/jittor/extern/acl/aclops/getitem_op_acl.h b/python/jittor/extern/acl/aclops/getitem_op_acl.h index 08426d41..481ab15f 100644 --- a/python/jittor/extern/acl/aclops/getitem_op_acl.h +++ b/python/jittor/extern/acl/aclops/getitem_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: MaskedSelectOpRunner(); }; @@ -18,6 +19,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: IndexOpRunner(); }; @@ -27,6 +29,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: SliceV2OpRunner(); }; @@ -36,6 +39,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: IndexPutImplAccumulateOpRunner(); }; @@ -45,6 +49,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: StridedSliceAssignV2OpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/index_op.py b/python/jittor/extern/acl/aclops/index_op.py index a280bc83..087f2a97 100644 --- a/python/jittor/extern/acl/aclops/index_op.py +++ b/python/jittor/extern/acl/aclops/index_op.py @@ -11,14 +11,15 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def range_forward(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None, - extra_data: dict = {}): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None, + extra_data: dict = {}): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -50,7 +51,8 @@ def range_forward(name: str, op.add(out0, false); {attr_code} op.run();""", - data=extra_data) + data=extra_data) + class IndexACL(jt.Function): @@ -85,15 +87,13 @@ class IndexACL(jt.Function): op.op_attr.reset(attr); """ result = range_forward("Range", [], - output_dtypes=[tmp.dtype], - output_shapes=[tmp.shape], - attr_code=range_attr_code, - extra_data=extra_data)[0] + output_dtypes=[tmp.dtype], + output_shapes=[tmp.shape], + attr_code=range_attr_code, + extra_data=extra_data)[0] broadcast_dims = list(range(len(inshape))) broadcast_dims.remove(d) - result = jt.broadcast(result, - shape=inshape, - dims=broadcast_dims) + result = jt.broadcast(result, shape=inshape, dims=broadcast_dims) results.append(result) if len(results) != 1 or dim_input == None: @@ -104,4 +104,4 @@ class IndexACL(jt.Function): return results def grad(self, grad_output): - return grad_output \ No newline at end of file + return grad_output diff --git a/python/jittor/extern/acl/aclops/index_op_acl.h b/python/jittor/extern/acl/aclops/index_op_acl.h index a244908f..e69bf39c 100644 --- a/python/jittor/extern/acl/aclops/index_op_acl.h +++ b/python/jittor/extern/acl/aclops/index_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: RangeOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/matmul_op.py b/python/jittor/extern/acl/aclops/matmul_op.py index 6f9b00a7..fbb8df71 100644 --- 
a/python/jittor/extern/acl/aclops/matmul_op.py +++ b/python/jittor/extern/acl/aclops/matmul_op.py @@ -12,15 +12,14 @@ from typing import Union from collections.abc import Sequence, Iterable - def matmul_forward(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None, - extra_data: dict = {}): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None, + extra_data: dict = {}): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -38,7 +37,7 @@ def matmul_forward(name: str, input_code = '' for i in range(len(inputs)): input_code += f"op.add(in{i}, true);\n" - + return jt.code(outputs=outputs_, inputs=inputs, cuda_header=attr_header + cuda_header, @@ -49,83 +48,83 @@ def matmul_forward(name: str, op.add(out0, false); {attr_code} op.run();""", - data=extra_data) - + data=extra_data) + + class MatmulACL(jt.Function): - def __init__(self, trans_x2=False): - super(MatmulACL, self).__init__() - self.trans_x2 = trans_x2 + def __init__(self, trans_x2=False): + super(MatmulACL, self).__init__() + self.trans_x2 = trans_x2 - def execute(self, x1, x2): - self.input = [x1, x2] - result = matmul_forward( - "MatMul", [x1, x2], - output_dtypes=[x1.dtype], - output_shapes=[ - x1.shape[:-1] + - x2.shape[-2:-1] if self.trans_x2 else x1.shape[:-1] + - x2.shape[-1:] - ], - attr_code="op.jt_name=\"matmul_trans_1\";" - if self.trans_x2 else "op.jt_name=\"matmul\";")[0] - return result + def execute(self, x1, x2): + self.input = [x1, x2] + result = matmul_forward( + "MatMul", [x1, x2], + output_dtypes=[x1.dtype], + output_shapes=[ + x1.shape[:-1] + + x2.shape[-2:-1] if self.trans_x2 else x1.shape[:-1] + + x2.shape[-1:] + ], + attr_code="op.jt_name=\"matmul_trans_1\";" + if self.trans_x2 else "op.jt_name=\"matmul\";")[0] + return result - def grad(self, grad_output): - x1, x2 = self.input - if len(x1) != len(x2): - reshape_grad_x2 = True + def grad(self, grad_output): + x1, x2 = self.input + if len(x1) != len(x2): + reshape_grad_x2 = True + else: + reshape_grad_x2 = False + grad_x1 = matmul_forward( + "MatMul", [grad_output, x2], + output_dtypes=[x1.dtype], + output_shapes=[ + grad_output.shape[:-1] + x2.shape[-2:-1] if not self.trans_x2 + else grad_output.shape[:-1] + x2.shape[-1:] + ], + attr_code="op.jt_name=\"matmul_trans_1\";" + if not self.trans_x2 else "op.jt_name=\"matmul\";")[0] + + if self.trans_x2: + if reshape_grad_x2: + output_shape = grad_output.shape[1:-2] + grad_output.shape[ + -1:] + x1.shape[-1:] + grad_x2 = matmul_forward( + "MatMul", [ + grad_output.reshape(-1, grad_output.shape[-1]), + x1.reshape(-1, x1.shape[-1]) + ], + output_dtypes=[x2.dtype], + output_shapes=[output_shape], + attr_code="op.jt_name=\"matmul_trans_0\";")[0] else: - reshape_grad_x2 = False - grad_x1 = matmul_forward( - "MatMul", [grad_output, x2], - output_dtypes=[x1.dtype], - output_shapes=[ - grad_output.shape[:-1] + x2.shape[-2:-1] - if not self.trans_x2 else grad_output.shape[:-1] + - x2.shape[-1:] - ], - attr_code="op.jt_name=\"matmul_trans_1\";" - if not self.trans_x2 else "op.jt_name=\"matmul\";")[0] - - if self.trans_x2: - if reshape_grad_x2: - output_shape = grad_output.shape[1:-2] + grad_output.shape[ - -1:] + x1.shape[-1:] - grad_x2 = matmul_forward( - "MatMul", [ - grad_output.reshape(-1, grad_output.shape[-1]), - x1.reshape(-1, x1.shape[-1]) - ], - output_dtypes=[x2.dtype], - 
output_shapes=[output_shape], - attr_code="op.jt_name=\"matmul_trans_0\";")[0] - else: - output_shape = grad_output.shape[:-2] + grad_output.shape[ - -1:] + x1.shape[-1:] - grad_x2 = matmul_forward( - "MatMul", [grad_output, x1], - output_dtypes=[x2.dtype], - output_shapes=[output_shape], - attr_code="op.jt_name=\"matmul_trans_0\";")[0] + output_shape = grad_output.shape[:-2] + grad_output.shape[ + -1:] + x1.shape[-1:] + grad_x2 = matmul_forward( + "MatMul", [grad_output, x1], + output_dtypes=[x2.dtype], + output_shapes=[output_shape], + attr_code="op.jt_name=\"matmul_trans_0\";")[0] + else: + if reshape_grad_x2: + output_shape = x1.shape[1:-2] + x1.shape[ + -1:] + grad_output.shape[-1:] + grad_x2 = matmul_forward( + "MatMul", [ + x1.reshape(-1, x1.shape[-1]), + grad_output.reshape(-1, grad_output.shape[-1]) + ], + output_dtypes=[x2.dtype], + output_shapes=[output_shape], + attr_code="op.jt_name=\"matmul_trans_0\";")[0] else: - if reshape_grad_x2: - output_shape = x1.shape[1:-2] + x1.shape[ - -1:] + grad_output.shape[-1:] - grad_x2 = matmul_forward( - "MatMul", [ - x1.reshape(-1, x1.shape[-1]), - grad_output.reshape(-1, grad_output.shape[-1]) - ], - output_dtypes=[x2.dtype], - output_shapes=[output_shape], - attr_code="op.jt_name=\"matmul_trans_0\";")[0] - else: - output_shape = x1.shape[:-2] + x1.shape[ - -1:] + grad_output.shape[-1:] - grad_x2 = matmul_forward( - "MatMul", [x1, grad_output], - output_dtypes=[x2.dtype], - output_shapes=[output_shape], - attr_code="op.jt_name=\"matmul_trans_0\";")[0] - return grad_x1, grad_x2 \ No newline at end of file + output_shape = x1.shape[:-2] + x1.shape[ + -1:] + grad_output.shape[-1:] + grad_x2 = matmul_forward( + "MatMul", [x1, grad_output], + output_dtypes=[x2.dtype], + output_shapes=[output_shape], + attr_code="op.jt_name=\"matmul_trans_0\";")[0] + return grad_x1, grad_x2 diff --git a/python/jittor/extern/acl/aclops/matmul_op_acl.cc b/python/jittor/extern/acl/aclops/matmul_op_acl.cc index 304b3298..af109cbb 100644 --- a/python/jittor/extern/acl/aclops/matmul_op_acl.cc +++ b/python/jittor/extern/acl/aclops/matmul_op_acl.cc @@ -49,7 +49,7 @@ namespace jittor for (int idx = 0; idx < input_num; idx++) { inputTensors.push_back(nullptr); - if ((jt_name == "matmul_trans_1" && idx == 1) || (jt_name == "matmul_trans_0" && idx == 0) ) + if ((jt_name == "matmul_trans_1" && idx == 1) || (jt_name == "matmul_trans_0" && idx == 0)) { auto ret = CreateFakeTransAclTensor(inputShapes[idx], in_[idx]->mem_ptr, in_[idx]->size, get_dtype(in_[idx]->dtype()), &inputTensors[idx], use_nchw); CHECK_RET(ret == ACL_SUCCESS, return); @@ -63,8 +63,8 @@ namespace jittor } void MatMulOpRunner::executeOp(std::unordered_map::iterator &it) { - - ret = aclnnMatmulGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor); + + ret = aclnnMatmulGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnMatmulGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return); if (workspaceSize > 0) { @@ -72,6 +72,6 @@ namespace jittor } ret = aclnnMatmul(workspaceAddr, workspaceSize, executor, aclstream); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnMatmul failed. 
ERROR: %d\n", name.c_str(), ret); return); - // syncRun(); + syncRun(); } } \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/matmul_op_acl.h b/python/jittor/extern/acl/aclops/matmul_op_acl.h index ce1f2507..ab82edc0 100644 --- a/python/jittor/extern/acl/aclops/matmul_op_acl.h +++ b/python/jittor/extern/acl/aclops/matmul_op_acl.h @@ -10,6 +10,7 @@ namespace jittor protected: void setupInputDesc() override; void executeOp(std::unordered_map::iterator &it) override; + public: MatMulOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/nantonum_op.py b/python/jittor/extern/acl/aclops/nantonum_op.py index 3816f698..2a36c999 100644 --- a/python/jittor/extern/acl/aclops/nantonum_op.py +++ b/python/jittor/extern/acl/aclops/nantonum_op.py @@ -11,13 +11,14 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def nantonum_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -51,6 +52,7 @@ def nantonum_cmd(name: str, {attr_code} op.run();""") + class NanToNumACL(jt.Function): def __init__(self): @@ -67,7 +69,7 @@ class NanToNumACL(jt.Function): """ self.attr_code = attr_code result = nantonum_cmd("NanToNum", [input], - output_dtypes=[input[0].dtype], - output_shapes=[input.shape], - attr_code=self.attr_code)[0] - return result \ No newline at end of file + output_dtypes=[input[0].dtype], + output_shapes=[input.shape], + attr_code=self.attr_code)[0] + return result diff --git a/python/jittor/extern/acl/aclops/nantonum_op_acl.h b/python/jittor/extern/acl/aclops/nantonum_op_acl.h index a52f37b4..924c0080 100644 --- a/python/jittor/extern/acl/aclops/nantonum_op_acl.h +++ b/python/jittor/extern/acl/aclops/nantonum_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: NanToNumOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/pool_op.py b/python/jittor/extern/acl/aclops/pool_op.py index 75953ed4..583420a6 100644 --- a/python/jittor/extern/acl/aclops/pool_op.py +++ b/python/jittor/extern/acl/aclops/pool_op.py @@ -11,13 +11,14 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def pool_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -51,33 +52,32 @@ def pool_cmd(name: str, {attr_code} op.run();""") + class PoolACL(jt.Function): def __init__(self, - kernel_size, - stride=None, - padding=0, - dilation=None, - return_indices=None, - ceil_mode=False, - count_include_pad=True, - op='maximum'): + kernel_size, + stride=None, + padding=0, + dilation=None, + return_indices=None, + ceil_mode=False, + count_include_pad=True, + op='maximum'): self.kernel_size = kernel_size if isinstance( kernel_size, tuple) else (kernel_size, kernel_size) stride = stride if stride else kernel_size - self.stride = stride if isinstance(stride, tuple) else (stride, - stride) + 
self.stride = stride if isinstance(stride, tuple) else (stride, stride) self.padding = padding if isinstance(padding, tuple) else (padding, - padding) + padding) dilation = dilation if dilation else 1 assert dilation == 1 - self.dilation = dilation if isinstance( - dilation, tuple) else (dilation, dilation) + self.dilation = dilation if isinstance(dilation, tuple) else (dilation, + dilation) for item in self.kernel_size: if item <= 0: raise RuntimeError( - f"kernel_size must be greater than zero, but got {item}" - ) + f"kernel_size must be greater than zero, but got {item}") for item in self.stride: if item <= 0: raise RuntimeError( @@ -108,7 +108,7 @@ class PoolACL(jt.Function): kernel_height, kernel_width = self.kernel_size[-2:] output_height = (input_height + 2 * self.padding[0] - - (kernel_height - 1) - 1) // self.stride[0] + 1 + (kernel_height - 1) - 1) // self.stride[0] + 1 output_width = (input_width + 2 * self.padding[1] - (kernel_width - 1) - 1) // self.stride[1] + 1 @@ -161,16 +161,16 @@ class PoolACL(jt.Function): output_dtypes = [input.dtype] if self.op == 'maximum': result = pool_cmd("MaxpoolBackward", - inputs=[grad_output, input, self.index], - output_dtypes=output_dtypes, - output_shapes=output_shapes, - attr_code=attr_code)[0] + inputs=[grad_output, input, self.index], + output_dtypes=output_dtypes, + output_shapes=output_shapes, + attr_code=attr_code)[0] elif self.op == 'mean': result = pool_cmd("AvgpoolBackward", - inputs=[grad_output, input], - output_dtypes=output_dtypes, - output_shapes=output_shapes, - attr_code=attr_code)[0] + inputs=[grad_output, input], + output_dtypes=output_dtypes, + output_shapes=output_shapes, + attr_code=attr_code)[0] else: raise ValueError('no this type pool') - return result \ No newline at end of file + return result diff --git a/python/jittor/extern/acl/aclops/pool_op_acl.cc b/python/jittor/extern/acl/aclops/pool_op_acl.cc index 5542b03e..8781b4ee 100644 --- a/python/jittor/extern/acl/aclops/pool_op_acl.cc +++ b/python/jittor/extern/acl/aclops/pool_op_acl.cc @@ -71,7 +71,6 @@ namespace jittor return; } - AvgpoolOpRunner::AvgpoolOpRunner() : BaseOpRunner("Avgpool") { use_nchw = true; @@ -109,7 +108,6 @@ namespace jittor return; } - MaxpoolBackwardOpRunner::MaxpoolBackwardOpRunner() : BaseOpRunner("MaxpoolBackward") { use_nchw = true; @@ -150,8 +148,6 @@ namespace jittor return; } - - AvgpoolBackwardOpRunner::AvgpoolBackwardOpRunner() : BaseOpRunner("AvgpoolBackward") { use_nchw = true; diff --git a/python/jittor/extern/acl/aclops/pool_op_acl.h b/python/jittor/extern/acl/aclops/pool_op_acl.h index 342aa6c8..5116314a 100644 --- a/python/jittor/extern/acl/aclops/pool_op_acl.h +++ b/python/jittor/extern/acl/aclops/pool_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: MaxpoolOpRunner(); }; @@ -18,16 +19,17 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: AvgpoolOpRunner(); }; - class MaxpoolBackwardOpRunner : public BaseOpRunner { protected: void executeOp(std::unordered_map::iterator &it) override; + public: MaxpoolBackwardOpRunner(); }; @@ -37,6 +39,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: AvgpoolBackwardOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/random_op_acl.cc b/python/jittor/extern/acl/aclops/random_op_acl.cc index 988fd8ff..2fb18eba 100644 --- a/python/jittor/extern/acl/aclops/random_op_acl.cc +++ 
b/python/jittor/extern/acl/aclops/random_op_acl.cc @@ -44,7 +44,7 @@ namespace jittor void RandomOpRunner::executeOp(std::unordered_map::iterator &it) { auto attr = dynamic_cast(op_attr.get()); - if(name == "RandomUniform") + if (name == "RandomUniform") { ret = aclnnInplaceUniformGetWorkspaceSize(outputTensors[0], 0.0, 1.0, attr->seed, attr->offset, &workspaceSize, &executor); @@ -58,7 +58,7 @@ namespace jittor ret = aclnnInplaceUniform(workspaceAddr, workspaceSize, executor, aclstream); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnInplaceUniform failed. ERROR: %d\n", name.c_str(), ret); return); } - else if(name == "RandomNormal") + else if (name == "RandomNormal") { ret = aclnnInplaceNormalGetWorkspaceSize(outputTensors[0], 0.0, 1.0, attr->seed, attr->offset, &workspaceSize, &executor); @@ -76,7 +76,7 @@ namespace jittor { LOGf << "Not supported random type : " << name; } - // syncRun(); + syncRun(); return; } } \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/reduce_op_acl.cc b/python/jittor/extern/acl/aclops/reduce_op_acl.cc index 9907993a..e0c0b069 100644 --- a/python/jittor/extern/acl/aclops/reduce_op_acl.cc +++ b/python/jittor/extern/acl/aclops/reduce_op_acl.cc @@ -28,7 +28,6 @@ #include "aclnn/aclnn.h" #include "reduce_op_acl.h" - namespace jittor { ReduceOpRunner::ReduceOpRunner() : BaseOpRunner("reduce") diff --git a/python/jittor/extern/acl/aclops/relu_op.py b/python/jittor/extern/acl/aclops/relu_op.py index cc7c1c0e..c321a810 100644 --- a/python/jittor/extern/acl/aclops/relu_op.py +++ b/python/jittor/extern/acl/aclops/relu_op.py @@ -11,13 +11,14 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def relu_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -51,6 +52,7 @@ def relu_cmd(name: str, {attr_code} op.run();""") + class ReLUACL(jt.Function): def __init__(self): @@ -60,9 +62,9 @@ class ReLUACL(jt.Function): x = x.float32() self.input = x result = relu_cmd("ReLU", [x], - output_dtypes=[x.dtype], - output_shapes=[x.shape], - attr_code="op.jt_name=\"unary\";")[0] + output_dtypes=[x.dtype], + output_shapes=[x.shape], + attr_code="op.jt_name=\"unary\";")[0] return result def grad(self, grad_output): @@ -72,11 +74,11 @@ class ReLUACL(jt.Function): output_shapes=[self.input.shape], attr_code="op.jt_name=\"binary\";")[0] grad_input = relu_cmd("Mul", [grad_output, mask], - output_dtypes=[grad_output.dtype], - output_shapes=[grad_output.shape], - attr_code="op.jt_name=\"binary\";")[0] + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr_code="op.jt_name=\"binary\";")[0] return grad_input - + class LeakyReLUACL(jt.Function): @@ -93,9 +95,9 @@ class LeakyReLUACL(jt.Function): op.op_attr.reset(attr); """ result = relu_cmd("LeakyReLU", [x], - output_dtypes=[x.dtype], - output_shapes=[x.shape], - attr_code=attr_code)[0] + output_dtypes=[x.dtype], + output_shapes=[x.shape], + attr_code=attr_code)[0] self.negative_slope = negative_slope return result @@ -107,9 +109,8 @@ class LeakyReLUACL(jt.Function): attr->selfIsResult = false; op.op_attr.reset(attr); """ - grad_input = relu_cmd("LeakyReLUBackward", - [grad_output, self.input], - 
output_dtypes=[grad_output.dtype], - output_shapes=[grad_output.shape], - attr_code=attr_code)[0] - return grad_input \ No newline at end of file + grad_input = relu_cmd("LeakyReLUBackward", [grad_output, self.input], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr_code=attr_code)[0] + return grad_input diff --git a/python/jittor/extern/acl/aclops/relu_op_acl.h b/python/jittor/extern/acl/aclops/relu_op_acl.h index 6ce713fe..c436dd1e 100644 --- a/python/jittor/extern/acl/aclops/relu_op_acl.h +++ b/python/jittor/extern/acl/aclops/relu_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: LeakyReLUOpRunner(); }; @@ -18,6 +19,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: LeakyReLUBackwardOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/rope_op.py b/python/jittor/extern/acl/aclops/rope_op.py index 9853dbda..71269cf1 100644 --- a/python/jittor/extern/acl/aclops/rope_op.py +++ b/python/jittor/extern/acl/aclops/rope_op.py @@ -11,13 +11,14 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def rope_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -51,6 +52,7 @@ def rope_cmd(name: str, {attr_code} op.run();""") + class RopeACL(jt.Function): def __init__(self): @@ -67,16 +69,16 @@ class RopeACL(jt.Function): assert freq_cos is not None and freq_sin is not None inputs = [xq, xk, freq_cos, freq_sin] results = rope_cmd("RotaryPosEmb", - inputs, - output_dtypes=[ - xq.dtype, - ], - output_shapes=[ - xq.shape, - ], - attr_code=attr_code) + inputs, + output_dtypes=[ + xq.dtype, + ], + output_shapes=[ + xq.shape, + ], + attr_code=attr_code) results[0].sync() return inputs[0], inputs[1] def grad(self, grad_output): - return grad_output \ No newline at end of file + return grad_output diff --git a/python/jittor/extern/acl/aclops/rope_op_acl.h b/python/jittor/extern/acl/aclops/rope_op_acl.h index 2b3e7594..0f1b2996 100644 --- a/python/jittor/extern/acl/aclops/rope_op_acl.h +++ b/python/jittor/extern/acl/aclops/rope_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: RotaryPosEmbOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/setitem_op.py b/python/jittor/extern/acl/aclops/setitem_op.py index e14ebe02..58eb0f62 100644 --- a/python/jittor/extern/acl/aclops/setitem_op.py +++ b/python/jittor/extern/acl/aclops/setitem_op.py @@ -11,13 +11,14 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def setitem_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -51,6 +52,7 @@ def setitem_cmd(name: str, {attr_code} op.run();""") + def setitem_forward(name: str, inputs: list, output_dtypes: list = None, @@ -90,7 +92,8 @@ 
def setitem_forward(name: str, op.add(out0, false); {attr_code} op.run();""", - data=extra_data) + data=extra_data) + def caculate_shape(tensors): @@ -105,6 +108,7 @@ def caculate_shape(tensors): else: assert False, f"not implemented for {type(tensors)}" + def can_broadcast_and_shape(shape1, shape2): """ Check whether two tensors can be broadcast together, and return the broadcast shape. @@ -144,6 +148,7 @@ def can_broadcast_and_shape(shape1, shape2): return True, tuple(broadcast_shape) + class SetItemACL(jt.Function): def __init__(self): @@ -180,9 +185,9 @@ class SetItemACL(jt.Function): op.jt_name = "inplacemaskedscatter"; """ result = setitem_cmd("InplaceMaskedScatter", - inputs=inputs, - outputs=outputs, - attr_code=attr_code)[0] + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] return result # assert isinstance(value,jt.Var), "value must be jt.Var" @@ -199,7 +204,7 @@ class SetItemACL(jt.Function): contains_slice = False for s in slices: if not isinstance(s, jt.Var) and (isinstance(s, slice) - or s == Ellipsis): + or s == Ellipsis): contains_slice = True break if not contains_slice: @@ -220,9 +225,9 @@ class SetItemACL(jt.Function): self.value_shape = value_shape for ii in slices: indices.append(jt.Var(ii).int32()) - if isinstance(slices[0], jt.Var) or isinstance( - slices[0], int) or isinstance( - slices[0], list) or isinstance(slices[0], tuple): + if isinstance(slices[0], + jt.Var) or isinstance(slices[0], int) or isinstance( + slices[0], list) or isinstance(slices[0], tuple): self.indices = indices self.type_ = 'index' attr_code = f""" @@ -231,9 +236,9 @@ class SetItemACL(jt.Function): inputs = [value] + indices outputs = [x.clone()] result = setitem_cmd("IndexPutImpl", - inputs=inputs, - outputs=outputs, - attr_code=attr_code)[0] + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] # result.sync() return result assert "not support" @@ -244,8 +249,7 @@ class SetItemACL(jt.Function): if not isinstance(s, jt.Var) and s == Ellipsis: slices = slices[:slices.index(s)] + [ slice(None, None, None) - ] * (x_dim - len(slices) + 1) + slices[slices.index(s) + - 1:] + ] * (x_dim - len(slices) + 1) + slices[slices.index(s) + 1:] break slices = tuple(slices) self.input_slice = slices @@ -335,10 +339,10 @@ class SetItemACL(jt.Function): inputs = [value] outputs = [x.clone()] result = setitem_forward("StridedSliceAssignV2", - inputs=inputs, - outputs=outputs, - attr_code=attr_code, - extra_data=extra_data)[0] + inputs=inputs, + outputs=outputs, + attr_code=attr_code, + extra_data=extra_data)[0] if expand_dim: result = result.squeeze(-1) # result.sync() @@ -349,4 +353,4 @@ class SetItemACL(jt.Function): if self.value_var: value_grad = grad_output[self.input_slice] grad_output[self.input_slice] = jt.zeros(self.value_shape) - return grad_output, None, value_grad \ No newline at end of file + return grad_output, None, value_grad diff --git a/python/jittor/extern/acl/aclops/setitem_op_acl.cc b/python/jittor/extern/acl/aclops/setitem_op_acl.cc index 7672eba2..1ed0e284 100644 --- a/python/jittor/extern/acl/aclops/setitem_op_acl.cc +++ b/python/jittor/extern/acl/aclops/setitem_op_acl.cc @@ -33,7 +33,7 @@ namespace jittor InplaceMaskedScatterOpRunner::InplaceMaskedScatterOpRunner() : BaseOpRunner("InplaceMaskedScatter") { } - + void InplaceMaskedScatterOpRunner::executeOp(std::unordered_map::iterator &it) { ret = aclnnInplaceMaskedScatterGetWorkspaceSize(outputTensors[0], inputTensors[0], inputTensors[1], &workspaceSize, &executor); @@ -55,7 +55,7 @@ namespace jittor IndexPutImplOpRunner::IndexPutImplOpRunner() :
BaseOpRunner("IndexPutImpl") { } - + void IndexPutImplOpRunner::executeOp(std::unordered_map::iterator &it) { auto input_num = in_.size(); @@ -77,7 +77,7 @@ namespace jittor ret = aclnnIndexPutImpl(workspaceAddr, workspaceSize, executor, aclstream); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnIndexPutImpl failed. ERROR: %d\n", name.c_str(), ret); return); - // syncRun(); + syncRun(); return; } diff --git a/python/jittor/extern/acl/aclops/setitem_op_acl.h b/python/jittor/extern/acl/aclops/setitem_op_acl.h index a0dfd52d..ddd73902 100644 --- a/python/jittor/extern/acl/aclops/setitem_op_acl.h +++ b/python/jittor/extern/acl/aclops/setitem_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: InplaceMaskedScatterOpRunner(); }; @@ -18,6 +19,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: IndexPutImplOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/sigmoid_op.py b/python/jittor/extern/acl/aclops/sigmoid_op.py index 49452482..ed3f1240 100644 --- a/python/jittor/extern/acl/aclops/sigmoid_op.py +++ b/python/jittor/extern/acl/aclops/sigmoid_op.py @@ -11,13 +11,14 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def sigmoid_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -51,6 +52,7 @@ def sigmoid_cmd(name: str, {attr_code} op.run();""") + class SigmoidACL(jt.Function): def __init__(self): @@ -64,9 +66,9 @@ class SigmoidACL(jt.Function): op.jt_name = "sigmoid"; """ result = sigmoid_cmd("Sigmoid", - inputs=inputs, - outputs=outputs, - attr_code=attr_code)[0] + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] self.output = result return result @@ -77,7 +79,7 @@ class SigmoidACL(jt.Function): inputs = [grad_output, self.output] outputs = [jt.empty(grad_output.shape, grad_output.dtype)] grad_input = sigmoid_cmd("SigmoidBackward", - inputs=inputs, - outputs=outputs, - attr_code=attr_code)[0] - return grad_input \ No newline at end of file + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] + return grad_input diff --git a/python/jittor/extern/acl/aclops/sigmoid_op_acl.h b/python/jittor/extern/acl/aclops/sigmoid_op_acl.h index bfc7191c..b175cd01 100644 --- a/python/jittor/extern/acl/aclops/sigmoid_op_acl.h +++ b/python/jittor/extern/acl/aclops/sigmoid_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: SigmoidOpRunner(); }; @@ -18,6 +19,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: SigmoidBackwardOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/silu_op.py b/python/jittor/extern/acl/aclops/silu_op.py index 9cc44abb..30613b6a 100644 --- a/python/jittor/extern/acl/aclops/silu_op.py +++ b/python/jittor/extern/acl/aclops/silu_op.py @@ -11,13 +11,14 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def silu_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None): + inputs: list, + 
output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -51,6 +52,7 @@ def silu_cmd(name: str, {attr_code} op.run();""") + class SiLUACL(jt.Function): def __init__(self): @@ -65,9 +67,9 @@ class SiLUACL(jt.Function): op.jt_name = "silu"; """ result = silu_cmd("SiLU", - inputs=inputs, - outputs=outputs, - attr_code=attr_code)[0] + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] return result def grad(self, grad_output): @@ -77,7 +79,7 @@ class SiLUACL(jt.Function): inputs = [grad_output, self.input] outputs = [jt.empty(grad_output.shape, grad_output.dtype)] grad_input = silu_cmd("SiLUBackward", - inputs=inputs, - outputs=outputs, - attr_code=attr_code)[0] - return grad_input \ No newline at end of file + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] + return grad_input diff --git a/python/jittor/extern/acl/aclops/silu_op_acl.h b/python/jittor/extern/acl/aclops/silu_op_acl.h index cabbef29..abc52810 100644 --- a/python/jittor/extern/acl/aclops/silu_op_acl.h +++ b/python/jittor/extern/acl/aclops/silu_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: SiLUOpRunner(); }; @@ -18,6 +19,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: SiLUBackwardOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/softmax_op.py b/python/jittor/extern/acl/aclops/softmax_op.py index 2c26a7eb..85ae9c72 100644 --- a/python/jittor/extern/acl/aclops/softmax_op.py +++ b/python/jittor/extern/acl/aclops/softmax_op.py @@ -11,13 +11,14 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def softmax_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -51,6 +52,7 @@ def softmax_cmd(name: str, {attr_code} op.run();""") + class SoftmaxACL(jt.Function): def __init__(self): @@ -68,9 +70,9 @@ class SoftmaxACL(jt.Function): op.op_attr.reset(attr); """ result = softmax_cmd("Softmax", - inputs=inputs, - outputs=outputs, - attr_code=attr_code)[0] + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] self.output = result return result @@ -84,7 +86,7 @@ class SoftmaxACL(jt.Function): inputs = [grad_output, self.output] outputs = [jt.empty(grad_output.shape)] grad_input = softmax_cmd("SoftmaxBackward", - inputs=inputs, - outputs=outputs, - attr_code=attr_code)[0] - return grad_input \ No newline at end of file + inputs=inputs, + outputs=outputs, + attr_code=attr_code)[0] + return grad_input diff --git a/python/jittor/extern/acl/aclops/softmax_op_acl.h b/python/jittor/extern/acl/aclops/softmax_op_acl.h index 8132ee6d..11af9d36 100644 --- a/python/jittor/extern/acl/aclops/softmax_op_acl.h +++ b/python/jittor/extern/acl/aclops/softmax_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: SoftmaxOpRunner(); }; @@ -18,6 +19,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: SoftmaxBackwardOpRunner(); }; diff --git 
a/python/jittor/extern/acl/aclops/stack_op.py b/python/jittor/extern/acl/aclops/stack_op.py index fa435e13..c9ba50b3 100644 --- a/python/jittor/extern/acl/aclops/stack_op.py +++ b/python/jittor/extern/acl/aclops/stack_op.py @@ -11,13 +11,14 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def stack_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -51,6 +52,7 @@ def stack_cmd(name: str, {attr_code} op.run();""") + class StackACL(jt.Function): def __init__(self): @@ -60,15 +62,12 @@ class StackACL(jt.Function): if type(input_tensors) is tuple: input_tensors = list(input_tensors) assert type(input_tensors) is list - assert -1 * len(input_tensors) - 1 <= dim and dim <= len( - input_tensors) + assert -1 * len(input_tensors) - 1 <= dim and dim <= len(input_tensors) for i in range(len(input_tensors)): if input_tensors[i].dtype != input_tensors[0].dtype: - raise ValueError( - "All input tensors must have the same dtype") + raise ValueError("All input tensors must have the same dtype") if input_tensors[i].shape != input_tensors[0].shape: - raise ValueError( - "All input tensors must have the same shape") + raise ValueError("All input tensors must have the same shape") self.input = input_tensors input_shape = list(input_tensors[0].shape) output_shape = input_shape[:dim] + [len(input_tensors) @@ -82,10 +81,10 @@ class StackACL(jt.Function): """ self.attr_code = attr_code result = stack_cmd("Stack", - input_tensors, - output_dtypes=[input_tensors[0].dtype], - output_shapes=[output_shape], - attr_code=self.attr_code)[0] + input_tensors, + output_dtypes=[input_tensors[0].dtype], + output_shapes=[output_shape], + attr_code=self.attr_code)[0] return result def grad(self, grad_output): @@ -110,7 +109,7 @@ class StackACL(jt.Function): """ result = stack_cmd("SplitWithSize", [grad_output], - output_dtypes=dtypeVec, - output_shapes=shapeVec, - attr_code=attr_code) - return result \ No newline at end of file + output_dtypes=dtypeVec, + output_shapes=shapeVec, + attr_code=attr_code) + return result diff --git a/python/jittor/extern/acl/aclops/stack_op_acl.h b/python/jittor/extern/acl/aclops/stack_op_acl.h index 758e1261..4b7df980 100644 --- a/python/jittor/extern/acl/aclops/stack_op_acl.h +++ b/python/jittor/extern/acl/aclops/stack_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: StackOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/ternary_op_acl.cc b/python/jittor/extern/acl/aclops/ternary_op_acl.cc index fc0d6fad..73f8fc4d 100644 --- a/python/jittor/extern/acl/aclops/ternary_op_acl.cc +++ b/python/jittor/extern/acl/aclops/ternary_op_acl.cc @@ -48,7 +48,7 @@ namespace jittor ret = aclnnSWhere(workspaceAddr, workspaceSize, executor, aclstream); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed. 
ERROR: %d\n", name.c_str(), ret); return); - //syncRun(); + // syncRun(); return; } } \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/ternary_op_acl.h b/python/jittor/extern/acl/aclops/ternary_op_acl.h index 3ff3b28e..2402c039 100644 --- a/python/jittor/extern/acl/aclops/ternary_op_acl.h +++ b/python/jittor/extern/acl/aclops/ternary_op_acl.h @@ -7,7 +7,7 @@ namespace jittor struct TernaryOpRunner : public BaseOpRunner { TernaryOpRunner(); - + protected: void executeOp(std::unordered_map::iterator &it) override; }; diff --git a/python/jittor/extern/acl/aclops/transpose_op.py b/python/jittor/extern/acl/aclops/transpose_op.py index 63baf4ef..aa1d7e55 100644 --- a/python/jittor/extern/acl/aclops/transpose_op.py +++ b/python/jittor/extern/acl/aclops/transpose_op.py @@ -11,13 +11,14 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def transpose_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -51,6 +52,7 @@ def transpose_cmd(name: str, {attr_code} op.run();""") + class TransPoseACL(jt.Function): def __init__(self): @@ -75,9 +77,9 @@ class TransPoseACL(jt.Function): # calculate output shape output_shape = [x.shape[i] for i in dim] output = transpose_cmd("Transpose", [x], - output_dtypes=[x.dtype], - output_shapes=[jt.empty(output_shape).shape], - attr_code=attr_code)[0] + output_dtypes=[x.dtype], + output_shapes=[jt.empty(output_shape).shape], + attr_code=attr_code)[0] self.dim = dim return output @@ -93,7 +95,7 @@ class TransPoseACL(jt.Function): op.op_attr.reset(attr); """ output = transpose_cmd("Transpose", [grad_output], - output_dtypes=[grad_output.dtype], - output_shapes=[jt.empty(output_shape).shape], - attr_code=attr_code)[0] + output_dtypes=[grad_output.dtype], + output_shapes=[jt.empty(output_shape).shape], + attr_code=attr_code)[0] return output diff --git a/python/jittor/extern/acl/aclops/transpose_op_acl.h b/python/jittor/extern/acl/aclops/transpose_op_acl.h index 9a89ebbb..737fffd8 100644 --- a/python/jittor/extern/acl/aclops/transpose_op_acl.h +++ b/python/jittor/extern/acl/aclops/transpose_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: TransposeOpRunner(); }; diff --git a/python/jittor/extern/acl/aclops/unary_op_acl.cc b/python/jittor/extern/acl/aclops/unary_op_acl.cc index 4a489e50..d1172fce 100644 --- a/python/jittor/extern/acl/aclops/unary_op_acl.cc +++ b/python/jittor/extern/acl/aclops/unary_op_acl.cc @@ -53,7 +53,7 @@ namespace jittor ret = it->second.executeFunc(workspaceAddr, workspaceSize, executor, aclstream); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed. 
ERROR: %d\n", name.c_str(), ret); return); - //syncRun(); + // syncRun(); return; } } \ No newline at end of file diff --git a/python/jittor/extern/acl/aclops/where_op.py b/python/jittor/extern/acl/aclops/where_op.py index 8115d699..f0417bb9 100644 --- a/python/jittor/extern/acl/aclops/where_op.py +++ b/python/jittor/extern/acl/aclops/where_op.py @@ -11,13 +11,14 @@ import numpy as np from typing import Union from collections.abc import Sequence, Iterable + def where_cmd(name: str, - inputs: list, - output_dtypes: list = None, - output_shapes: list = None, - attr_code: str = "", - attr_header: str = "", - outputs: list = None): + inputs: list, + output_dtypes: list = None, + output_shapes: list = None, + attr_code: str = "", + attr_header: str = "", + outputs: list = None): attr_header = "\nnamespace jittor{" + attr_header + "}\n" cuda_header = ''' @@ -51,6 +52,7 @@ def where_cmd(name: str, {attr_code} op.run();""") + class NonzeroACL(jt.Function): def __init__(self): @@ -63,15 +65,16 @@ class NonzeroACL(jt.Function): nonzero_cnt = (x != 0.0).sum().item() result = where_cmd("Nonzero", [x], - output_dtypes=['int64'], - output_shapes=[(nonzero_cnt, x.ndim)], - attr_code=attr_code)[0] + output_dtypes=['int64'], + output_shapes=[(nonzero_cnt, x.ndim)], + attr_code=attr_code)[0] return result def grad(self, grad_output): return grad_output + class WhereACL(jt.Function): def __init__(self): @@ -104,9 +107,9 @@ class WhereACL(jt.Function): self.y = y result = where_cmd("Where", [condition, x, y], - output_dtypes=[x.dtype], - output_shapes=[x.shape], - attr_code="op.jt_name=\"where\";")[0] + output_dtypes=[x.dtype], + output_shapes=[x.shape], + attr_code="op.jt_name=\"where\";")[0] return result def grad(self, grad_output): @@ -115,12 +118,12 @@ class WhereACL(jt.Function): else: tmp = jt.zeros(grad_output.shape, dtype=grad_output.dtype) grad_x = where_cmd("Where", [self.condition, grad_output, tmp], - output_dtypes=[self.x.dtype], - output_shapes=[self.x.shape], - attr_code="op.jt_name=\"where\";")[0] + output_dtypes=[self.x.dtype], + output_shapes=[self.x.shape], + attr_code="op.jt_name=\"where\";")[0] grad_y = where_cmd("Where", [self.condition, tmp, grad_output], - output_dtypes=[self.y.dtype], - output_shapes=[self.y.shape], - attr_code="op.jt_name=\"where\";")[0] + output_dtypes=[self.y.dtype], + output_shapes=[self.y.shape], + attr_code="op.jt_name=\"where\";")[0] return grad_output, grad_x, grad_y diff --git a/python/jittor/extern/acl/aclops/where_op_acl.cc b/python/jittor/extern/acl/aclops/where_op_acl.cc index f2b44057..a1d2ddc8 100644 --- a/python/jittor/extern/acl/aclops/where_op_acl.cc +++ b/python/jittor/extern/acl/aclops/where_op_acl.cc @@ -53,7 +53,6 @@ namespace jittor return; } - NonzeroOpRunner::NonzeroOpRunner() : BaseOpRunner("Nonzero") { } diff --git a/python/jittor/extern/acl/aclops/where_op_acl.h b/python/jittor/extern/acl/aclops/where_op_acl.h index 37d0d227..d881f752 100644 --- a/python/jittor/extern/acl/aclops/where_op_acl.h +++ b/python/jittor/extern/acl/aclops/where_op_acl.h @@ -9,6 +9,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: WhereOpRunner(); }; @@ -18,6 +19,7 @@ namespace jittor protected: void executeOp(std::unordered_map::iterator &it) override; + public: NonzeroOpRunner(); };
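For reference, a minimal usage sketch of the two Functions touched in the where_op.py hunks above (WhereACL and NonzeroACL). This is illustrative only, not part of the patch: it assumes a Jittor build with the ACL (Ascend) backend active, the absolute import path is inferred from the file layout shown in this diff, and the shapes and values are made up.

import jittor as jt
# Import path assumed from python/jittor/extern/acl/aclops/where_op.py
from jittor.extern.acl.aclops.where_op import WhereACL, NonzeroACL

cond = jt.rand(2, 2) > 0.5                 # boolean selection mask
x = jt.ones((2, 2), dtype="float32")
y = jt.zeros((2, 2), dtype="float32")

out = WhereACL()(cond, x, y)               # picks x where cond is true, y elsewhere
idx = NonzeroACL()(x)                      # int64 coordinates, shape (nonzero_count, x.ndim)
print(out.shape, idx.shape)

On the ACL path both calls delegate to where_cmd, so the results come back as ordinary jt.Var objects and participate in autograd through the grad methods defined above.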