mirror of https://github.com/Jittor/Jittor
# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def _ntuple(n):

    def parse(x):
        if isinstance(x, Iterable):
            return x
        return tuple([x] * n)

    return parse


_pair = _ntuple(2)

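# Example (illustrative only): `_pair` normalizes scalar or iterable arguments
# the same way torch-style conv/pool helpers do:
#   _pair(3)      -> (3, 3)
#   _pair((1, 2)) -> (1, 2)   # iterables are returned unchanged
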
has_acl = 0
cc_flags = ""
tikcc_path = env_or_try_find('tikcc_path', 'ccec')
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL
compiler.has_acl = has_acl

# export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/tools/aoe/lib64:/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64:/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/plugin/opskernel:/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/plugin/nnengine:/usr/local/Ascend/ascend-toolkit/latest/runtime/lib64:/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/stub:/usr/local/Ascend/ascend-toolkit/latest/tools/tikicpulib/lib/Ascend910A:/usr/local/Ascend/ascend-toolkit/latest/toolkit/tools/simulator/Ascend910A/lib:/opt/AXESMI/lib64:/usr/local/Ascend/driver/lib64/driver/
# export PYTHONPATH=/home/cjld/new_jittor/jittor/python
# export tikcc_path=g++

# conda activate cann
# source /usr/local/Ascend/ascend-toolkit/set_env.sh
# export PYTHONPATH=/home/cjld/new_jittor/jittor/python:/home/cjld/new_jittor/jittor/my/jtorch/python:$PYTHONPATH
# export TASK_QUEUE_ENABLE=0
# python3 -m jittor.test.test_acl -k array
# jittor: conda activate cann && source /usr/local/Ascend/ascend-toolkit/set_env.sh && PYTHONPATH=/home/cjld/new_jittor/jittor/python:/home/cjld/new_jittor/jittor/my/jtorch/python:$PYTHONPATH && cd /home/cjld/new_jittor/jittor/my/mm_benchmark
# python3 -m jittor.test.test_acl -k test_sum
# export ASCEND_SLOG_PRINT_TO_STDOUT=0
# ASCEND_GLOBAL_LOG_LEVEL
# export DUMP_GE_GRAPH=1
# export DUMP_GRAPH_LEVEL=1

# build pytorch-npu
# bash ./ci/build.sh
# python3 -m pip install ./dist/torch_npu-1.11.0.post1-cp37-cp37m-linux_x86_64.whl --force-reinstall
# pytorch: conda activate cann && source /usr/local/Ascend/ascend-toolkit/set_env.sh && export TASK_QUEUE_ENABLE=0 && cd /home/cjld/new_jittor/jittor/my/mm_benchmark
# python3 ./mm_bench_pt_npu.py


def install():
    import jittor.compiler as compiler
    global has_acl, cc_flags
    acl_compiler_home = os.path.dirname(__file__)
    cc_files = sorted(glob.glob(acl_compiler_home + "/**/*.cc",
                                recursive=True))
    cc_files2 = []
    for name in cc_files:
        # if "acl_op_exec" in name or "_op_acl.cc" in name:
        if "acl_op_exec" in name or "_op_acl.cc" in name or "utils.cc" in name:
            compiler.extra_core_files.append(name)
        else:
            cc_files2.append(name)
    cc_files = cc_files2
    ascend_toolkit_home = os.getenv('ASCEND_TOOLKIT_HOME')

    # print(ascend_toolkit_home)
    # print(acl_compiler_home)
    cc_flags += f" -MD -DHAS_CUDA -DIS_ACL \
    -I{ascend_toolkit_home}/include/ \
    -I{ascend_toolkit_home}/include/acl/ \
    -I{ascend_toolkit_home}/include/aclnn/ \
    -I{ascend_toolkit_home}/include/aclnnop/ \
    -I{acl_compiler_home} -lascendcl -lacl_op_compiler \
    -I{acl_compiler_home}/aclnn \
    -I{acl_compiler_home}/aclops \
    -L{ascend_toolkit_home}/lib64/"

    cc_flags += " -llibascendcl "
    cc_flags += " -llibnnopbase "
    cc_flags += " -llibopapi "

    # pdb.set_trace()
    ctypes.CDLL("libascendcl.so", dlopen_flags)
    # NOTE: the f-string below is a bare expression and has no effect; it is
    # kept only as a reference for extra include/link flags.
    f'''
    -ltikc_runtime
    -I/usr/local/Ascend/driver/include/ \
    -L{ascend_toolkit_home}/compiler/lib64/ \
    -L{ascend_toolkit_home}/runtime/lib64/ \
    '''
    jittor_utils.LOG.i("ACL detected")

    global mod
    mod = jittor_utils.compile_module(
        '''
        #include "common.h"
        namespace jittor {
        // @pyjt(process)
        string process_acl(const string& src, const string& name, const map<string,string>& kargs);
        // @pyjt(init_acl_ops)
        void init_acl_ops();
        }''', compiler.cc_flags + " " + " ".join(cc_files) + cc_flags)
    jittor_utils.process_jittor_source("acl", mod.process)

    has_acl = 1
    os.environ["use_mkl"] = "0"
    compiler.setup_fake_cuda_lib = True


def install_extern():
    return False


def check():
    import jittor.compiler as compiler
    global has_acl, cc_flags
    if tikcc_path:
        try:
            install()
        except Exception as e:
            jittor_utils.LOG.w(f"load ACL failed, exception: {e}")
            has_acl = 0
    compiler.has_acl = has_acl
    compiler.tikcc_path = tikcc_path
    if not has_acl: return False
    compiler.cc_flags += cc_flags
    compiler.nvcc_path = tikcc_path
    compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14", "")
    return True


def post_process():
    if has_acl:
        from jittor import pool
        pool.pool_use_code_op = False
        import jittor as jt
        jt.flags.use_cuda_host_allocator = 1
        jt.flags.use_parallel_op_compiler = 0
        jt.flags.amp_reg |= 32 + 4  # 32 keep float16, 4 keep reduce type
        mod.init_acl_ops()


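# Rough, illustrative sketch of how the three hooks above are driven (the real
# call sites live elsewhere in Jittor's compiler bootstrap, not here):
#
#   if check():          # probes tikcc/ccec and compiles the ACL glue module
#       ...              # compiler.cc_flags now carry the CANN include/lib paths
#       post_process()   # disables unsupported code paths and registers acl ops
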
def acl_cmd(name: str,
            inputs: list,
            output_dtypes: list = None,
            output_shapes: list = None,
            attr_code: str = "",
            attr_header: str = "",
            outputs: list = None):
    """Run the ACL operator `name` through a jt.code op.

    Outputs are either the given `outputs` list or freshly allocated from
    `output_shapes`/`output_dtypes`. The generated cuda_src feeds every input
    and output to an AclOpRunner and executes it; the jt.code outputs are
    returned (callers typically take element [0]).
    """
    # inputs: list,
    # output_dtypes: list,
    # output_shapes: list,
    # attr_code: str = ""):
    # input_code = ''
    # for i in range(len(inputs)):
    #     input_code += f"op.add(in{i}, true);\n"

    # output_code = ''
    # for i in range(len(output_dtypes)):
    #     output_code += f"op.add(out{i}, false);\n"

    # # read the tmp_file.cpp to the cuda_header
    # with open(
    #         "/home/ma-user/work/zy/JittorHW/python/jittor/extern/acl/tmp_file.cpp",
    #         "r") as f:
    #     cuda_header = f.read()
    # import jittor as jt
    # return jt.code(output_shapes,
    #                output_dtypes,
    #                inputs,
    #                cuda_header=cuda_header,
    #                cuda_src=f"""
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    # read the tmp_file.cpp to the cuda_header

    cuda_header = '#include "acl/aclops/aclops.h"'
    import jittor as jt
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        # print(f'{name } output_dtypes', output_dtypes)
        # print(f'{name } output_shapes', output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    # print(f'{name } outputs_', outputs_)
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""

    // aclop
    AclOpRunner op("{name}");
    {input_code}
    {output_code}
    {attr_code}
    op.run();""")


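# Illustrative use of `acl_cmd` (mirrors the TriuACL wrapper further below;
# not executed here):
#
#   y = acl_cmd("Triu", [x],
#               output_dtypes=[x.dtype],
#               output_shapes=[x.shape],
#               attr_code='op.jt_name = "triu"; ...')[0]
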
def acl_cmd_forward(name: str,
                    inputs: list,
                    output_dtypes: list = None,
                    output_shapes: list = None,
                    attr_code: str = "",
                    attr_header: str = "",
                    outputs: list = None,
                    extra_data: dict = {}):
    """Same as `acl_cmd`, but forwards `extra_data` to jt.code's `data`
    argument so the generated attr_code can read per-call scalars
    (e.g. data["..."]) at execution time.
    """
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    import jittor as jt
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        # print(f'{name } output_dtypes', output_dtypes)
        # print(f'{name } output_shapes', output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    # print(f'{name } outputs_', outputs_)
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"

    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    AclOpRunner op("{name}");
    {input_code}
    {output_code}
    {attr_code}
    op.run();""",
                   data=extra_data)


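# Illustrative difference from `acl_cmd` (a sketch based on the commented-out
# IndexACL code further below): scalars in `extra_data` become visible to the
# generated C++ through jt.code's `data` map, e.g.
#
#   acl_cmd_forward("Range", [],
#                   output_dtypes=["int32"], output_shapes=[(n,)],
#                   attr_code='... attr->start = data["dim_0_start"]; ...',
#                   extra_data={"dim_0_start": 0, "dim_0_end": n, "dim_0_step": 1})
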
def change_function():
    import jittor as jt
    from jittor import Function

    class TriuACL(Function):

        def __init__(self):
            super(TriuACL, self).__init__()

        def execute(self, input, diagonal):
            attr_code = f"""
            op.jt_name = "triu";
            TriuAttr *attr = new TriuAttr();
            attr->diagonal = {diagonal};
            op.op_attr.reset(attr);
            """

            result = acl_cmd("Triu", [input],
                             output_dtypes=[input.dtype],
                             output_shapes=[input.shape],
                             attr_code=attr_code)[0]
            return result

        def grad(self, grad_output):
            return grad_output

    def triu_acl(x, diagonal=0):
        return TriuACL()(x, diagonal)

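    # Example (inside change_function; illustrative only):
    #   m = jt.rand(4, 4)
    #   up = triu_acl(m, diagonal=1)  # keeps elements strictly above the main diagonal
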
# class ConvACL(Function):
|
||
|
||
# def execute(self,
|
||
# x,
|
||
# weight,
|
||
# bias=None,
|
||
# stride=1,
|
||
# padding=0,
|
||
# dilation=1,
|
||
# groups=1):
|
||
# self.input = x
|
||
# self.weight = weight
|
||
# self.bias = bias
|
||
# padding = _pair(padding)
|
||
# stride = _pair(stride)
|
||
# dilation = _pair(dilation)
|
||
# out_channels = weight.shape[0]
|
||
# if groups <= 0:
|
||
# raise ValueError("groups must be a positive integer")
|
||
# self.padding = padding
|
||
# self.stride = stride
|
||
# self.dilation = dilation
|
||
# self.groups = groups
|
||
# attr_code = f"""
|
||
# op.jt_name = "conv2d";
|
||
# ConvAttr *attr = new ConvAttr();
|
||
# attr->convStrides = {{ {stride[0]}, {stride[1]} }};
|
||
# attr->convPads = {{ {padding[0]}, {padding[1]} }};
|
||
# attr->convDilations = {{ {dilation[0]}, {dilation[1]} }};
|
||
# attr->group = {groups};
|
||
# attr->convOutPads = {{ 1,1}};
|
||
# op.op_attr.reset(attr);
|
||
# """
|
||
# input_height, input_width = x.shape[-2:]
|
||
# kernel_height, kernel_width = weight.shape[-2:]
|
||
|
||
# output_height = (input_height + 2 * padding[0] - dilation[0] *
|
||
# (kernel_height - 1) - 1) // stride[0] + 1
|
||
# output_width = (input_width + 2 * padding[1] - dilation[1] *
|
||
# (kernel_width - 1) - 1) // stride[1] + 1
|
||
|
||
# output_shape = (x.shape[0], out_channels, output_height,
|
||
# output_width)
|
||
|
||
# inputs = [x, weight]
|
||
# if bias is not None:
|
||
# inputs.append(bias)
|
||
# result = acl_cmd(
|
||
# "Conv2d",
|
||
# inputs,
|
||
# output_dtypes=[x.dtype],
|
||
# output_shapes=[output_shape],
|
||
# attr_code=attr_code,
|
||
# )[0]
|
||
# return result
|
||
|
||
# def grad(self, grad_output):
|
||
# x = self.input
|
||
# weight = self.weight
|
||
# bias = self.bias
|
||
# inputs = [grad_output, x, weight]
|
||
# if bias is not None:
|
||
# inputs.append(bias)
|
||
# output_shapes = [x.shape, weight.shape]
|
||
# output_dtypes = [x.dtype, weight.dtype]
|
||
# if bias is not None:
|
||
# output_shapes.append(bias.shape)
|
||
# output_dtypes.append(bias.dtype)
|
||
# else:
|
||
# output_shapes.append([1])
|
||
# output_dtypes.append(x.dtype)
|
||
# padding = self.padding
|
||
# stride = self.stride
|
||
# dilation = self.dilation
|
||
# groups = self.groups
|
||
# attr_code = f"""
|
||
# op.jt_name = "conv2dbackward";
|
||
# ConvAttr *attr = new ConvAttr();
|
||
# attr->convStrides = {{ {stride[0]}, {stride[1]} }};
|
||
# attr->convPads = {{ {padding[0]}, {padding[1]} }};
|
||
# attr->convDilations = {{ {dilation[0]}, {dilation[1]} }};
|
||
# attr->group = {groups};
|
||
# attr->convOutPads = {{ 1,1}};
|
||
# op.op_attr.reset(attr);
|
||
# """
|
||
# results = acl_cmd("Conv2dBackward",
|
||
# inputs,
|
||
# output_dtypes=output_dtypes,
|
||
# output_shapes=output_shapes,
|
||
# attr_code=attr_code)
|
||
# if self.bias is None:
|
||
# return results[0], results[1]
|
||
|
||
# return results
|
||
|
||
    from .aclops.conv_op import ConvACL

    def conv_acl(x,
                 weight,
                 bias=None,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1):
        return ConvACL()(x, weight, bias, stride, padding, dilation, groups)

    class Conv2D(jt.nn.Module):

        def __init__(self,
                     in_channels,
                     out_channels,
                     kernel_size,
                     stride=1,
                     padding=0,
                     dilation=1,
                     groups=1,
                     bias=True):
            if in_channels <= 0:
                raise ValueError(
                    f"in_channels must be greater than zero, got {in_channels}")
            if out_channels <= 0:
                raise ValueError(
                    f"out_channels must be greater than zero, got {out_channels}")
            if groups <= 0:
                raise ValueError(
                    f"groups must be greater than zero, got {groups}")
            assert in_channels % groups == 0, 'in_channels must be divisible by groups'
            assert out_channels % groups == 0, 'out_channels must be divisible by groups'
            if isinstance(kernel_size, tuple):
                for size in kernel_size:
                    if size <= 0:
                        raise ValueError(
                            f"kernel_size must be greater than zero, got {kernel_size}")
            else:
                if kernel_size <= 0:
                    raise ValueError(
                        f"kernel_size must be greater than zero, got {kernel_size}")
            if isinstance(stride, tuple):
                for size in stride:
                    if size <= 0:
                        raise ValueError(
                            f"stride must be greater than zero, got {stride}")
            else:
                if stride <= 0:
                    raise ValueError(
                        f"stride must be greater than zero, got {stride}")
            if isinstance(padding, tuple):
                for size in padding:
                    if size < 0:
                        raise ValueError(
                            f"padding must be nonnegative, got {padding}")
            else:
                if padding < 0:
                    raise ValueError(
                        f"padding must be nonnegative, got {padding}")
            if isinstance(dilation, tuple):
                for size in dilation:
                    if size <= 0:
                        raise ValueError(
                            f"dilation must be greater than zero, got {dilation}")
            else:
                if dilation <= 0:
                    raise ValueError(
                        f"dilation must be greater than zero, got {dilation}")
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.kernel_size = kernel_size if isinstance(
                kernel_size, tuple) else (kernel_size, kernel_size)
            self.stride = stride if isinstance(stride, tuple) else (stride, stride)
            self.padding = padding if isinstance(padding, tuple) else (padding, padding)
            self.dilation = dilation if isinstance(
                dilation, tuple) else (dilation, dilation)
            self.groups = groups
            self.is_depthwise_conv = self.groups == self.out_channels and self.groups == self.in_channels
            if self.is_depthwise_conv and jt.flags.use_cuda and jt.compiler.is_cuda:
                self.depthwise_conv = jt.nn.DepthwiseConv(stride, padding, dilation)
            Kh, Kw = self.kernel_size

            # self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
            self.weight = jt.init.invariant_uniform(
                [out_channels, in_channels // groups, Kh, Kw], dtype="float")
            if bias:
                fan = 1
                for i in self.weight.shape[1:]:
                    fan *= i
                bound = 1 / math.sqrt(fan)
                self.bias = jt.init.uniform([out_channels],
                                            dtype="float",
                                            low=-bound,
                                            high=bound)
            else:
                self.bias = None

        def execute(self, x):
            ret = jt.nn.conv2d(x, self.weight, self.bias, self.stride,
                               self.padding, self.dilation, self.groups)
            return ret

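    # Example (illustrative only): Conv2D mirrors jt.nn.Conv2d's constructor;
    # the spatial output size follows the usual convolution formula
    #   H_out = (H + 2*padding - dilation*(kernel - 1) - 1) // stride + 1
    #
    #   conv = Conv2D(3, 16, kernel_size=3, stride=2, padding=1)
    #   y = conv(jt.rand(1, 3, 32, 32))   # y.shape == [1, 16, 16, 16]
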
# class PoolACL(Function):
|
||
|
||
# def __init__(self,
|
||
# kernel_size,
|
||
# stride=None,
|
||
# padding=0,
|
||
# dilation=None,
|
||
# return_indices=None,
|
||
# ceil_mode=False,
|
||
# count_include_pad=True,
|
||
# op='maximum'):
|
||
# self.kernel_size = kernel_size if isinstance(
|
||
# kernel_size, tuple) else (kernel_size, kernel_size)
|
||
# stride = stride if stride else kernel_size
|
||
# self.stride = stride if isinstance(stride, tuple) else (stride,
|
||
# stride)
|
||
# self.padding = padding if isinstance(padding, tuple) else (padding,
|
||
# padding)
|
||
# dilation = dilation if dilation else 1
|
||
# assert dilation == 1
|
||
# self.dilation = dilation if isinstance(
|
||
# dilation, tuple) else (dilation, dilation)
|
||
# for item in self.kernel_size:
|
||
# if item <= 0:
|
||
# raise RuntimeError(
|
||
# f"kernel_size must be greater than zero, but got {item}"
|
||
# )
|
||
# for item in self.stride:
|
||
# if item <= 0:
|
||
# raise RuntimeError(
|
||
# f"stride must be greater than zero, but got {item}")
|
||
# for item in self.padding:
|
||
# if item < 0:
|
||
# raise RuntimeError(
|
||
# f"padding must be non-negative, but got {item}")
|
||
# self.op = op
|
||
# self.return_indices = return_indices
|
||
# self.ceil_mode = ceil_mode
|
||
# self.count_include_pad = count_include_pad
|
||
|
||
# def execute(self, input):
|
||
# self.input = input
|
||
# attr_code = f"""
|
||
# op.jt_name = "{"avgpool" if self.op == 'mean' else "maxpool"}";
|
||
# PoolAttr *attr = new PoolAttr();
|
||
# attr->kernel_size = {{ {self.kernel_size[0]}, {self.kernel_size[1]} }};
|
||
# attr->poolStrides = {{ {self.stride[0]}, {self.stride[1]} }};
|
||
# attr->poolPads = {{ {self.padding[0]}, {self.padding[1]} }};
|
||
# attr->poolDilations = {{ {self.dilation[0]}, {self.dilation[1]} }};
|
||
# attr->poolCeil = {"true" if self.ceil_mode else "false"};
|
||
# attr->countIncludePad = {"true" if self.count_include_pad else "false"};
|
||
# op.op_attr.reset(attr);
|
||
# """
|
||
# input_height, input_width = input.shape[-2:]
|
||
# kernel_height, kernel_width = self.kernel_size[-2:]
|
||
|
||
# output_height = (input_height + 2 * self.padding[0] -
|
||
# (kernel_height - 1) - 1) // self.stride[0] + 1
|
||
# output_width = (input_width + 2 * self.padding[1] -
|
||
# (kernel_width - 1) - 1) // self.stride[1] + 1
|
||
|
||
# output_shape = (input.shape[0], input.shape[1], output_height,
|
||
# output_width)
|
||
|
||
# inputs = [input]
|
||
|
||
# if self.op == 'maximum':
|
||
# result = acl_cmd(
|
||
# "Maxpool",
|
||
# inputs,
|
||
# output_dtypes=[input.dtype, 'int32'],
|
||
# output_shapes=[output_shape, output_shape],
|
||
# attr_code=attr_code,
|
||
# )
|
||
# elif self.op == 'mean':
|
||
# result = acl_cmd(
|
||
# "Avgpool",
|
||
# inputs,
|
||
# output_dtypes=[input.dtype],
|
||
# output_shapes=[output_shape],
|
||
# attr_code=attr_code,
|
||
# )
|
||
# else:
|
||
# raise ValueError('no this type pool')
|
||
|
||
# if self.op == 'maximum':
|
||
# self.index = result[1]
|
||
|
||
# if self.return_indices:
|
||
# return result[0], result[1]
|
||
# else:
|
||
# return result[0]
|
||
|
||
# def grad(self, grad_output):
|
||
# input = self.input
|
||
# attr_code = f"""
|
||
# op.jt_name = "{"avgpoolbackward" if self.op == 'mean' else "maxpoolbackward"}";
|
||
# PoolAttr *attr = new PoolAttr();
|
||
# attr->kernel_size = {{ {self.kernel_size[0]}, {self.kernel_size[1]} }};
|
||
# attr->poolStrides = {{ {self.stride[0]}, {self.stride[1]} }};
|
||
# attr->poolPads = {{ {self.padding[0]}, {self.padding[1]} }};
|
||
# attr->poolDilations = {{ {self.dilation[0]}, {self.dilation[1]} }};
|
||
# attr->poolCeil = {"true" if self.ceil_mode else "false"};
|
||
# attr->countIncludePad = {"true" if self.count_include_pad else "false"};
|
||
# op.op_attr.reset(attr);
|
||
# """
|
||
# output_shapes = [input.shape]
|
||
# output_dtypes = [input.dtype]
|
||
# if self.op == 'maximum':
|
||
# result = acl_cmd("MaxpoolBackward",
|
||
# inputs=[grad_output, input, self.index],
|
||
# output_dtypes=output_dtypes,
|
||
# output_shapes=output_shapes,
|
||
# attr_code=attr_code)[0]
|
||
# elif self.op == 'mean':
|
||
# result = acl_cmd("AvgpoolBackward",
|
||
# inputs=[grad_output, input],
|
||
# output_dtypes=output_dtypes,
|
||
# output_shapes=output_shapes,
|
||
# attr_code=attr_code)[0]
|
||
# else:
|
||
# raise ValueError('no this type pool')
|
||
# return result
|
||
|
||
# class FlipACL(Function):
|
||
|
||
# def __init__(self):
|
||
# super(FlipACL, self).__init__()
|
||
|
||
# def execute(self, input, dim):
|
||
# if type(dim) is tuple:
|
||
# dim = list(dim)
|
||
# if type(dim) is not list:
|
||
# dim = [dim]
|
||
# attr_code = f"""
|
||
# op.jt_name = "flip";
|
||
# ReduceAttr *attr = new ReduceAttr();
|
||
# attr->axes = {{{', '.join(map(str, (list(dim))))}}};
|
||
# attr->prod_dim = {len(dim)};
|
||
# op.op_attr.reset(attr);
|
||
# """
|
||
# self.attr_code = attr_code
|
||
# result = acl_cmd("Flip", [input],
|
||
# output_dtypes=[input.dtype],
|
||
# output_shapes=[input.shape],
|
||
# attr_code=self.attr_code)[0]
|
||
# return result
|
||
|
||
# def grad(self, grad_output):
|
||
# grad_input = acl_cmd("Flip", [grad_output],
|
||
# output_dtypes=[grad_output.dtype],
|
||
# output_shapes=[grad_output.shape],
|
||
# attr_code=self.attr_code)[0]
|
||
# return grad_input
|
||
|
||
    from .aclops.flip_op import FlipACL

    def flip_acl(x, dim):
        return FlipACL()(x, dim)

# class ConcatACL(Function):
|
||
|
||
# def __init__(self):
|
||
# super(ConcatACL, self).__init__()
|
||
|
||
# def __call__(self, *args):
|
||
# assert isinstance(args[0], list)
|
||
# assert isinstance(args[1], int)
|
||
# if jt.flags.no_grad:
|
||
# return self.execute(*args)
|
||
# backup = args
|
||
# args = list(args)
|
||
# taped_inputs = []
|
||
# taped_outputs = []
|
||
# input_mask = [-1] * (len(args[0]) + 1)
|
||
# newargs = [list(), args[1]]
|
||
# for i, v in enumerate(args[0]):
|
||
# if isinstance(v, jt.Var):
|
||
# if v.is_stop_grad():
|
||
# # -2 in input_mask represents it is stop_grad
|
||
# input_mask[i] = -2
|
||
# newargs[0].append(v)
|
||
# continue
|
||
# v = v.tape()
|
||
# newargs[0].append(v)
|
||
# input_mask[i] = len(taped_inputs)
|
||
# taped_inputs.append(v)
|
||
|
||
# ori_res = self.execute(*newargs)
|
||
# if not isinstance(ori_res, Sequence):
|
||
# res = [ori_res]
|
||
# else:
|
||
# res = list(ori_res)
|
||
# output_mask = [-1] * len(res)
|
||
# for i, v in enumerate(res):
|
||
# if isinstance(v, jt.Var):
|
||
# v = v.tape()
|
||
# output_mask[i] = len(taped_outputs)
|
||
# res[i] = v
|
||
# taped_outputs.append(v)
|
||
# self.input_mask = input_mask
|
||
# self.output_mask = output_mask
|
||
# # tape output and input together so
|
||
# # backward treat them as one operator
|
||
# jt.tape_together(taped_inputs, taped_outputs, self._grad)
|
||
# if isinstance(ori_res, Sequence):
|
||
# return res
|
||
# else:
|
||
# return res[0]
|
||
|
||
# def execute(self, input_tensors, dim=0):
|
||
# for _ in input_tensors:
|
||
# if not (-_.ndim <= dim < _.ndim):
|
||
# print(_.shape, dim)
|
||
# raise ValueError("dim out of range")
|
||
|
||
# if dim < 0:
|
||
# dim += input_tensors[0].ndim
|
||
|
||
# self.input = input_tensors
|
||
# self.dim = dim
|
||
# for i in range(len(input_tensors)):
|
||
# if input_tensors[i].dtype != input_tensors[0].dtype:
|
||
# raise ValueError(
|
||
# "All input tensors must have the same dtype")
|
||
# if input_tensors[i].shape[:dim] != input_tensors[
|
||
# 0].shape[:dim] or input_tensors[i].shape[
|
||
# dim + 1:] != input_tensors[0].shape[dim + 1:]:
|
||
# raise ValueError(
|
||
# "All input tensors must have the same shape")
|
||
# attr_code = f"""
|
||
# op.jt_name = "concat";
|
||
# ConcatAttr *attr = new ConcatAttr();
|
||
# attr->tensorNum = {len(input_tensors)};
|
||
# attr->dim = {dim};
|
||
# op.op_attr.reset(attr);
|
||
# """
|
||
# result = acl_cmd(
|
||
# "Concat",
|
||
# input_tensors,
|
||
# output_dtypes=[input_tensors[0].dtype],
|
||
# output_shapes=[
|
||
# jt.empty(self.calculate_output_shape(input_tensors,
|
||
# dim)).shape
|
||
# ],
|
||
# attr_code=attr_code)[0]
|
||
# return result
|
||
|
||
# def _grad(self, *args):
|
||
# new_args = ((args[i] if i >= 0 else None)
|
||
# for i in self.output_mask)
|
||
# ret = self.grad(*new_args)
|
||
# new_ret = []
|
||
# for i, r in enumerate(ret):
|
||
# j = self.input_mask[i]
|
||
# if j < 0:
|
||
# # -2 in input_mask represents it is stop_grad
|
||
# assert r is None or j==-2, f"{type(self)}'s {i}-th returned grad should be None, "\
|
||
# "because the input value is not jittor variable."
|
||
# else:
|
||
# new_ret.append(r)
|
||
# return new_ret
|
||
|
||
# def grad(self, grad_output):
|
||
# grad_inputs = self.split_grad(grad_output, self.input, self.dim)
|
||
# return grad_inputs
|
||
|
||
# def calculate_output_shape(self, input_tensors, axis):
|
||
# shape = list(input_tensors[0].shape)
|
||
# for tensor in input_tensors[1:]:
|
||
# shape[axis] += tensor.shape[axis]
|
||
# return tuple(shape)
|
||
|
||
# def split_grad(self, grad_output, input_tensors, axis):
|
||
# offset = []
|
||
# shapeVec = []
|
||
# dtypeVec = []
|
||
# for tensor in input_tensors:
|
||
# offset.append(tensor.shape[axis])
|
||
# dtypeVec.append(tensor.dtype)
|
||
# shapeVec.append(tensor.shape)
|
||
|
||
# attr_code = f"""
|
||
# op.jt_name = "splitwithsize";
|
||
# auto *attr = new SplitWithSizeAttr();
|
||
# attr->splitSize = {{ {", ".join(map(str, offset))} }};
|
||
# attr->dim = {axis};
|
||
# op.op_attr.reset(attr);
|
||
# """
|
||
|
||
# result = acl_cmd("SplitWithSize", [grad_output],
|
||
# output_dtypes=dtypeVec,
|
||
# output_shapes=shapeVec,
|
||
# attr_code=attr_code)
|
||
# return result
|
||
|
||
    from .aclops.concat_op import ConcatACL

    def concat(x, dim=0):
        return ConcatACL()(x, dim)

# class GatherACL(Function):
|
||
|
||
# def __init__(self):
|
||
# super(GatherACL, self).__init__()
|
||
|
||
# def execute(self, input, dim, index):
|
||
# self.dim = dim
|
||
# self.index = index
|
||
# attr_code = f"""
|
||
# op.jt_name = "gather";
|
||
# GatherAttr *attr = new GatherAttr();
|
||
# attr->dim = {dim};
|
||
# op.op_attr.reset(attr);
|
||
# """
|
||
# result = acl_cmd("Gather", [input, index],
|
||
# output_dtypes=[input.dtype],
|
||
# output_shapes=[index.shape],
|
||
# attr_code=attr_code)[0]
|
||
# return result
|
||
|
||
# def grad(self, grad_output):
|
||
# tmp = jt.zeros(self.index.shape, dtype=grad_output.dtype)
|
||
# attr_code = f"""
|
||
# op.jt_name = "scatter";
|
||
# ScatterAttr *attr = new ScatterAttr();
|
||
# attr->axis = {self.dim};
|
||
# attr->reduction = {1};
|
||
# op.op_attr.reset(attr);
|
||
# """
|
||
# grad_input = acl_cmd("Scatter", [tmp, self.index, grad_output],
|
||
# output_dtypes=[grad_output.dtype],
|
||
# output_shapes=[tmp.shape],
|
||
# attr_code=attr_code)[0]
|
||
# return grad_input
|
||
|
||
    from .aclops.gather_scatter_op import GatherACL

    def gather_acl(input, dim, index):
        return GatherACL()(input, dim, index)

    def any_acl(input):
        if jt.sum(input != 0).item() > 0:
            return jt.array([True])
        else:
            return jt.array([False])

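    # Example (illustrative only): `gather_acl` follows torch.gather-style
    # semantics, picking values along `dim` according to `index`:
    #   src = jt.array([[1, 2], [3, 4]])
    #   gather_acl(src, 1, jt.array([[0, 0], [1, 0]]))  # -> [[1, 1], [4, 3]]
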
# class CumsumACL(Function):
|
||
|
||
# def __init__(self):
|
||
# super(CumsumACL, self).__init__()
|
||
|
||
# def execute(self, input, dim=-1):
|
||
# self.dim = dim
|
||
# attr_code = f"""
|
||
# op.jt_name = "cumsum";
|
||
# GatherAttr *attr = new GatherAttr();
|
||
# attr->dim = {dim};
|
||
# op.op_attr.reset(attr);
|
||
# """
|
||
# result = acl_cmd("Cumsum", [input],
|
||
# output_dtypes=[input.dtype],
|
||
# output_shapes=[input.shape],
|
||
# attr_code=attr_code)[0]
|
||
# return result
|
||
|
||
# def grad(self, grad_output):
|
||
# cumsum_attr_code = f"""
|
||
# op.jt_name = "cumsum";
|
||
# GatherAttr *attr = new GatherAttr();
|
||
# attr->dim = {self.dim};
|
||
# op.op_attr.reset(attr);
|
||
# """
|
||
# flip_attr_code = f"""
|
||
# op.jt_name = "flip";
|
||
# ReduceAttr *attr = new ReduceAttr();
|
||
# attr->axes = {{{self.dim}}};
|
||
# attr->prod_dim = {{{1}}};
|
||
# op.op_attr.reset(attr);
|
||
# """
|
||
# flipped_grad_output = acl_cmd("Flip", [grad_output],
|
||
# output_dtypes=[grad_output.dtype],
|
||
# output_shapes=[grad_output.shape],
|
||
# attr_code=flip_attr_code)[0]
|
||
# cumulative_grad = acl_cmd("Cumsum", [flipped_grad_output],
|
||
# output_dtypes=[grad_output.dtype],
|
||
# output_shapes=[grad_output.shape],
|
||
# attr_code=cumsum_attr_code)[0]
|
||
# grad_input = acl_cmd("Flip", [cumulative_grad],
|
||
# output_dtypes=[grad_output.dtype],
|
||
# output_shapes=[grad_output.shape],
|
||
# attr_code=flip_attr_code)[0]
|
||
# return grad_input
|
||
|
||
    from .aclops.cumsum_op import CumsumACL

    def cumsum_acl(input, dim=-1):
        return CumsumACL()(input, dim)

    def cumprod_acl(x, dim=None):
        # cumprod is computed as exp(cumsum(log(x))); this assumes the
        # entries of x are positive.
        x = jt.log(x)
        x = cumsum_acl(x, dim=dim)
        return jt.exp(x)

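    # Worked example of the identity used above (illustrative only):
    #   x = jt.array([1.0, 2.0, 3.0])
    #   cumprod_acl(x, dim=0)  # ~ [1., 2., 6.], since exp(cumsum(log x)) == cumprod(x)
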
# class IndexACL(Function):
|
||
|
||
# def __init__(self):
|
||
# super(IndexACL, self).__init__()
|
||
|
||
# def execute(self, inshape: list, dim=None, dtype="int32"):
|
||
# # zeros a tensor, shape is inshape, dtype is dtype
|
||
# dim_input = dim
|
||
# if dim == None:
|
||
# dim = [i for i in range(len(inshape))]
|
||
# elif type(dim) == int:
|
||
# dim = [dim]
|
||
# results = []
|
||
# extra_data = {}
|
||
# extra_data["dim_count"] = len(dim)
|
||
|
||
# for i, d in enumerate(dim):
|
||
# max_len = inshape[d]
|
||
|
||
# extra_data[f"dim_{i}_start"] = 0
|
||
# extra_data[f"dim_{i}_end"] = max_len
|
||
# extra_data[f"dim_{i}_step"] = 1
|
||
|
||
# tmp = jt.zeros(max_len, dtype=dtype)
|
||
# range_attr_code = f"""
|
||
# op.jt_name = "range";
|
||
# RangeAttr *attr = new RangeAttr();
|
||
# attr->start = data["dim_{i}_start"];
|
||
# attr->end = data["dim_{i}_end"];
|
||
# attr->step = data["dim_{i}_step"];
|
||
# op.op_attr.reset(attr);
|
||
# """
|
||
# result = acl_cmd_forward("Range", [],
|
||
# output_dtypes=[tmp.dtype],
|
||
# output_shapes=[tmp.shape],
|
||
# attr_code=range_attr_code,
|
||
# extra_data=extra_data)[0]
|
||
# broadcast_dims = list(range(len(inshape)))
|
||
# broadcast_dims.remove(d)
|
||
# result = jt.broadcast(result,
|
||
# shape=inshape,
|
||
# dims=broadcast_dims)
|
||
# results.append(result)
|
||
|
||
# if len(results) != 1 or dim_input == None:
|
||
# return tuple(results)
|
||
# elif len(results) == 1 and dim_input != None:
|
||
# return results[0]
|
||
# else:
|
||
# return results
|
||
|
||
# def grad(self, grad_output):
|
||
# return grad_output
|
||
|
||
    from .aclops.index_op import IndexACL

    def index_acl(inshape: Union[jt.Var, list], dim=None, dtype="int32"):
        if isinstance(inshape, jt.Var):
            inshape = inshape.shape
        return IndexACL()(inshape, dim, dtype)

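    # Example (illustrative only): like jt.index, this returns coordinate grids.
    #   index_acl([2, 3], dim=0)  # -> [[0, 0, 0], [1, 1, 1]]
    #   index_acl([2, 3], dim=1)  # -> [[0, 1, 2], [0, 1, 2]]
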
# class ScatterACL(Function):
|
||
|
||
# def __init__(self):
|
||
# super(ScatterACL, self).__init__()
|
||
|
||
# def execute(self, input, dim, index, src, reduce='void'):
|
||
# self.dim = dim
|
||
# self.index = index
|
||
# self.reduce = reduce
|
||
# attr_code = f"""
|
||
# op.jt_name = "scatter";
|
||
# ScatterAttr *attr = new ScatterAttr();
|
||
# attr->axis = {dim};
|
||
# attr->reduction = {1 if reduce == 'add' else 2 if reduce == 'mul' else 0};
|
||
# op.op_attr.reset(attr);
|
||
# """
|
||
# result = acl_cmd("Scatter", [input, self.index, src],
|
||
# output_dtypes=[input.dtype],
|
||
# output_shapes=[input.shape],
|
||
# attr_code=attr_code)[0]
|
||
# return result
|
||
|
||
# def grad(self, grad_output):
|
||
# attr_code = f"""
|
||
# op.jt_name = "gather";
|
||
# GatherAttr *attr = new GatherAttr();
|
||
# attr->dim = {self.dim};
|
||
# op.op_attr.reset(attr);
|
||
# """
|
||
# grad_input = acl_cmd("Gather", [grad_output, self.index],
|
||
# output_dtypes=[grad_output.dtype],
|
||
# output_shapes=[self.index.shape],
|
||
# attr_code=attr_code)[0]
|
||
# return grad_output, None, None, grad_input
|
||
|
||
    from .aclops.gather_scatter_op import ScatterACL

    def scatter_acl(input, dim, index, src, reduce='void'):
        return ScatterACL()(input, dim, index, src, reduce)

# class WhereACL(Function):
|
||
|
||
# def __init__(self):
|
||
# super(WhereACL, self).__init__()
|
||
|
||
# def execute(self, condition, x=None, y=None):
|
||
# # case 1 (unary)
|
||
# if y is None:
|
||
# self.unary = True
|
||
|
||
# # In this case, `condition` is the input, while `x` is dtype
|
||
# result = nonzero_acl(condition).t()
|
||
# result = [result[i] for i in range(result.size(0))]
|
||
# return result
|
||
# # The return value should be a tuple, but even we set to tuple here, it will be convert to a list in `Function.__call__`.
|
||
|
||
# # case 2 (cond ? x : y)
|
||
# else:
|
||
# self.condition = condition
|
||
|
||
# if x.dtype != y.dtype:
|
||
# if x.dtype == jt.float32:
|
||
# y = y.float32()
|
||
# elif y.dtype == jt.float32:
|
||
# x = x.float32()
|
||
# else:
|
||
# x = x.to(y.dtype)
|
||
|
||
# self.x = x
|
||
# self.y = y
|
||
|
||
# result = acl_cmd("Where", [condition, x, y],
|
||
# output_dtypes=[x.dtype],
|
||
# output_shapes=[x.shape],
|
||
# attr_code="op.jt_name=\"where\";")[0]
|
||
# return result
|
||
|
||
# def grad(self, grad_output):
|
||
# if hasattr(self, 'unary') and self.unary:
|
||
# return grad_output
|
||
# else:
|
||
# tmp = jt.zeros(grad_output.shape, dtype=grad_output.dtype)
|
||
# grad_x = acl_cmd("Where", [self.condition, grad_output, tmp],
|
||
# output_dtypes=[self.x.dtype],
|
||
# output_shapes=[self.x.shape],
|
||
# attr_code="op.jt_name=\"where\";")[0]
|
||
|
||
# grad_y = acl_cmd("Where", [self.condition, tmp, grad_output],
|
||
# output_dtypes=[self.y.dtype],
|
||
# output_shapes=[self.y.shape],
|
||
# attr_code="op.jt_name=\"where\";")[0]
|
||
# return grad_output, grad_x, grad_y
|
||
|
||
    from .aclops.where_op import WhereACL

    def where_acl(condition, x=None, y=None):
        return WhereACL()(condition, x, y)

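    # Example (illustrative only; the two calling modes mirror the commented-out
    # WhereACL code above):
    #   where_acl(cond)        # unary: per-dimension index Vars of the nonzero entries
    #   where_acl(cond, a, b)  # ternary: elementwise a where cond is true, else b
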
# class NonzeroACL(Function):
|
||
|
||
# def __init__(self):
|
||
# super(NonzeroACL, self).__init__()
|
||
|
||
# def execute(self, x):
|
||
# attr_code = f"""
|
||
# op.jt_name = "nonzero";
|
||
# """
|
||
# nonzero_cnt = (x != 0.0).sum().item()
|
||
|
||
# result = acl_cmd("Nonzero", [x],
|
||
# output_dtypes=['int64'],
|
||
# output_shapes=[(nonzero_cnt, x.ndim)],
|
||
# attr_code=attr_code)[0]
|
||
|
||
# return result
|
||
|
||
# def grad(self, grad_output):
|
||
# return grad_output
|
||
|
||
    from .aclops.where_op import NonzeroACL

    def nonzero_acl(x):
        return NonzeroACL()(x)

# class FloorIntACL(Function):
|
||
|
||
# def __init__(self):
|
||
# super(FloorIntACL, self).__init__()
|
||
|
||
# def execute(self, input):
|
||
# self.shape = input.shape
|
||
# result = acl_cmd("Floor", [input],
|
||
# output_dtypes=[input.dtype],
|
||
# output_shapes=[input.shape],
|
||
# attr_code="op.jt_name=\"floor\";")[0]
|
||
# return result
|
||
|
||
# def grad(self, grad_output):
|
||
# return jt.zeros(self.shape, dtype=grad_output.dtype)
|
||
|
||
    from .aclops.floor_op import FloorIntACL

    def floor_int_acl(x):
        return FloorIntACL()(x)

# def caculate_shape(tensors):
|
||
# if isinstance(tensors, jt.Var):
|
||
# # tensors = tensors[0]
|
||
# return tensors.shape
|
||
# elif isinstance(tensors, (int, float)):
|
||
# return []
|
||
# elif isinstance(tensors, (list, tuple)):
|
||
# # return [caculate_shape(tensor) for tensor in tensors]
|
||
# sub_shape = caculate_shape(tensors[0])
|
||
# return [len(tensors)] + sub_shape
|
||
# else:
|
||
# assert False, f"not implemented for {type(tensors)}"
|
||
|
||
# def can_broadcast_and_shape(shape1, shape2):
|
||
# """
|
||
# 检查两个张量是否可以广播,并返回广播后的形状。
|
||
|
||
# 参数:
|
||
# - shape1: 第一个张量的形状(tuple 或 list)
|
||
# - shape2: 第二个张量的形状(tuple 或 list)
|
||
|
||
# 返回:
|
||
# - can_broadcast: 布尔值,表示是否可以广播
|
||
# - broadcast_shape: 如果可以广播,返回广播后的形状;否则返回 None
|
||
# """
|
||
# # 将形状转换为元组,以防输入是列表
|
||
# shape1 = tuple(shape1)
|
||
# shape2 = tuple(shape2)
|
||
|
||
# # 使两个形状的长度一致,通过在前面补1
|
||
# len1, len2 = len(shape1), len(shape2)
|
||
# if len1 < len2:
|
||
# shape1 = (1, ) * (len2 - len1) + shape1
|
||
# elif len2 < len1:
|
||
# shape2 = (1, ) * (len1 - len2) + shape2
|
||
|
||
# broadcast_shape = []
|
||
|
||
# # 从最后一维开始检查每一维度
|
||
# for dim1, dim2 in zip(shape1, shape2):
|
||
# if dim1 == dim2:
|
||
# broadcast_shape.append(dim1)
|
||
# elif dim1 == 1:
|
||
# broadcast_shape.append(dim2)
|
||
# elif dim2 == 1:
|
||
# broadcast_shape.append(dim1)
|
||
# else:
|
||
# # 如果在某一维度上不兼容,则不能广播
|
||
# return False, None
|
||
|
||
# return True, tuple(broadcast_shape)
|
||
|
||
# class GetItemACL(Function):
|
||
|
||
# def __init__(self):
|
||
# self.type_ = 'notype'
|
||
|
||
# def stride(self, x, dim):
|
||
# stride = 1
|
||
# for i in range(dim + 1, len(x.shape)):
|
||
# stride *= x.shape[i]
|
||
# return stride
|
||
|
||
# def execute(self, x, slices, return_x=None):
|
||
# if isinstance(slices, jt.Var) and slices.dtype == 'bool':
|
||
# # assert False, "not support bool type now"
|
||
# #TODO:优化
|
||
# assert x.shape == slices.shape, "shape not match"
|
||
# output_len = slices.sum().item()
|
||
# # output = jt.empty((output_len,),dtype=x.dtype)
|
||
# x_len = x.numel()
|
||
# output = jt.empty((x_len), dtype=x.dtype)
|
||
# outputs = [output]
|
||
# inputs = [x, slices]
|
||
# # print(inputs,outputs)
|
||
# # print(output.shape)
|
||
# self.mask = slices
|
||
# self.type_ = 'mask'
|
||
# attr_code = f"""
|
||
# op.jt_name = "maskedselect";
|
||
# """
|
||
# result = acl_cmd("MaskedSelect",
|
||
# inputs=inputs,
|
||
# outputs=outputs,
|
||
# attr_code=attr_code)[0]
|
||
# result = result[:output_len]
|
||
# result.sync()
|
||
# return result
|
||
# self.x_shape = x.shape
|
||
|
||
# if not isinstance(slices, tuple):
|
||
# slices = (slices, )
|
||
# slices = list(slices)
|
||
# for i, s in enumerate(slices):
|
||
# if isinstance(s, int) and s < 0:
|
||
# slices[i] = s + x.shape[i]
|
||
# slices = tuple(slices)
|
||
# slices_list = list(slices)
|
||
# # if not isinstance(slices[0], slice):
|
||
# #check slices contains slice type
|
||
# contains_slice = False
|
||
# for s in slices:
|
||
# if not isinstance(s, jt.Var) and (isinstance(s, slice)
|
||
# or s == Ellipsis):
|
||
# contains_slice = True
|
||
# break
|
||
# if not contains_slice:
|
||
# indices = []
|
||
# output_shape = []
|
||
# slices_len = len(slices)
|
||
# boardcast_shape = caculate_shape(slices_list[0])
|
||
# for ii in range(1, len(slices)):
|
||
# dd, boardcast_shape = can_broadcast_and_shape(
|
||
# boardcast_shape, caculate_shape(slices_list[ii]))
|
||
# assert dd is True, "can not broadcast"
|
||
# output_shape = boardcast_shape
|
||
# output_shape += x.shape[slices_len:]
|
||
# if output_shape == []:
|
||
# output_shape = [1]
|
||
# for ii in slices:
|
||
# indices.append(jt.Var(ii).int32())
|
||
# if isinstance(slices[0], jt.Var) or isinstance(
|
||
# slices[0], int) or isinstance(
|
||
# slices[0], list) or isinstance(slices[0], tuple):
|
||
# self.indices = indices
|
||
# inputs = [x] + indices
|
||
# attr_code = f"""
|
||
# op.jt_name = "index";
|
||
# """
|
||
# self.type_ = 'index'
|
||
# result = acl_cmd("Index",
|
||
# inputs=inputs,
|
||
# output_dtypes=[x.dtype],
|
||
# output_shapes=[output_shape],
|
||
# attr_code=attr_code)[0]
|
||
# result.sync()
|
||
# return result
|
||
# assert contains_slice, "slice type error"
|
||
# x_dim = len(x.shape)
|
||
# slices = list(slices)
|
||
# for s in slices:
|
||
# if not isinstance(s, jt.Var) and s == Ellipsis:
|
||
# slices = slices[:slices.index(s)] + [
|
||
# slice(None, None, None)
|
||
# ] * (x_dim - len(slices) + 1) + slices[slices.index(s) +
|
||
# 1:]
|
||
# break
|
||
# slices = tuple(slices)
|
||
|
||
# if len(slices) < x_dim:
|
||
# slices += (slice(None, None, None), ) * (x_dim - len(slices))
|
||
# inputs = [x]
|
||
# sizes = []
|
||
# begins = []
|
||
# ends = []
|
||
# steps = []
|
||
# dims = []
|
||
# squeeze_dims = []
|
||
|
||
# extra_data = {}
|
||
# if len(slices):
|
||
# extra_data["a"] = len(slices)
|
||
# for dim, s in enumerate(slices):
|
||
# if isinstance(s, int):
|
||
# s = slice(s, s + 1, 1)
|
||
# squeeze_dims.append(dim)
|
||
# if isinstance(s, jt.Var):
|
||
# assert False, "jt.Var not supported"
|
||
# start, stop, step = s.indices(x.size(dim))
|
||
# size = (stop - start - 1) // step + 1
|
||
# # stride = self.stride(x, dim) * step
|
||
# sizes.append(size)
|
||
# extra_data[str(dim * 3)] = start
|
||
# extra_data[str(dim * 3 + 1)] = stop
|
||
# extra_data[str(dim * 3 + 2)] = step
|
||
|
||
# steps.append(step)
|
||
# begins.append(start)
|
||
# ends.append(stop)
|
||
# dims.append(dim)
|
||
# else:
|
||
# extra_data["a"] = -1
|
||
# sizes = [1]
|
||
# steps = [1]
|
||
# self.type_ = 'slicev2'
|
||
|
||
# # for backward
|
||
# self.begins = begins
|
||
# self.ends = ends
|
||
# self.steps = steps
|
||
# self.dims = dims
|
||
|
||
# self.slices = slices
|
||
# attr_code = """
|
||
# op.jt_name = "slicev2";
|
||
# StrideAttr *attr = new StrideAttr();
|
||
|
||
# int slice_dim = data["a"];
|
||
|
||
# if(slice_dim == -1) {
|
||
# attr->begins = {};
|
||
# attr->ends = {};
|
||
# attr->steps = {1};
|
||
# attr->axes = {};
|
||
# } else {
|
||
# vector<long int> begins;
|
||
# vector<long int> ends;
|
||
# vector<long int> steps;
|
||
# vector<long int> dims;
|
||
# for(int dim = 0; dim < slice_dim; dim++) {
|
||
# dims.push_back(dim);
|
||
# begins.push_back(data[std::to_string(dim*3)]);
|
||
# ends.push_back(data[std::to_string(dim*3+1)]);
|
||
# steps.push_back(data[std::to_string(dim*3+2)]);
|
||
# }
|
||
# attr->begins = begins;
|
||
# attr->ends = ends;
|
||
# attr->steps = steps;
|
||
# attr->axes = dims;
|
||
# }
|
||
# op.op_attr.reset(attr);
|
||
# """
|
||
# result = acl_cmd_forward("SliceV2",
|
||
# inputs,
|
||
# output_dtypes=[x.dtype],
|
||
# output_shapes=[jt.empty(sizes).shape],
|
||
# attr_code=attr_code,
|
||
# extra_data=extra_data)[0]
|
||
# self.squeeze_dims = squeeze_dims
|
||
# for dim in squeeze_dims[::-1]:
|
||
# result = jt.squeeze(result, dim)
|
||
# result.sync()
|
||
# return result
|
||
|
||
# def grad(self, grad_output):
|
||
# if self.type_ == 'index':
|
||
# indices = self.indices
|
||
# inputs = [grad_output] + indices
|
||
# attr_code = f"""
|
||
# op.jt_name = "indexputimplaccumulate";
|
||
# """
|
||
# outputs = [jt.zeros(self.x_shape, dtype=grad_output.dtype)]
|
||
# # breakpoint()
|
||
# result = acl_cmd("IndexPutImplAccumulate",
|
||
# inputs=inputs,
|
||
# outputs=outputs,
|
||
# attr_code=attr_code)[0]
|
||
# result.sync()
|
||
# return result, None
|
||
# elif self.type_ == 'slicev2':
|
||
# begins = self.begins
|
||
# ends = self.ends
|
||
# steps = self.steps
|
||
# dims = self.dims
|
||
# slices = self.slices
|
||
# #注意前向的维数可能会被压缩,所以这里要还原
|
||
# for dim in self.squeeze_dims:
|
||
# grad_output = jt.unsqueeze(grad_output, dim)
|
||
# #适配华为奇怪的要求,最后一个维度的step必须是1
|
||
# expand_dim = False
|
||
# if isinstance(slices[-1], slice):
|
||
# if slices[-1].step is not None and slices[-1].step != 1:
|
||
# slices = slices + (slice(None, None, None), )
|
||
# expand_dim = True
|
||
# elif isinstance(slices[-1], int):
|
||
# #注意最后一个维度是数字
|
||
# slices = list(slices)
|
||
# slices[-1] = slice(slices[-1], slices[-1] + 1, 1)
|
||
# slices = tuple(slices)
|
||
# slices = slices + (slice(None, None, None), )
|
||
# expand_dim = True
|
||
# else:
|
||
# assert False, "not supported"
|
||
# # x = x.unsqueeze(-1)
|
||
# if expand_dim:
|
||
# grad_output = grad_output.unsqueeze(-1)
|
||
# self.x_shape = self.x_shape + (1, )
|
||
# sizes = []
|
||
# begins = []
|
||
# ends = []
|
||
# steps = []
|
||
# dims = []
|
||
# for dim, s in enumerate(slices):
|
||
# if isinstance(s, int):
|
||
# s = slice(s, s + 1, 1)
|
||
# # squeeze_dims.append(dim)
|
||
# if isinstance(s, jt.Var):
|
||
# assert False, "jt.Var not supported"
|
||
# start, stop, step = s.indices(self.x_shape[dim])
|
||
# size = (stop - start - 1) // step + 1
|
||
# # stride = self.stride(x, dim) * step
|
||
# sizes.append(size)
|
||
# steps.append(step)
|
||
# begins.append(start)
|
||
# ends.append(stop)
|
||
# dims.append(dim)
|
||
# if not sizes:
|
||
# sizes = [1]
|
||
# steps = [1]
|
||
# attr_code = f"""
|
||
# op.jt_name = "stridedsliceassignv2";
|
||
# StrideAttr *attr = new StrideAttr();
|
||
# attr->begins = {{ {", ".join(map(str, begins))} }};
|
||
# attr->ends = {{ {", ".join(map(str, ends))} }};
|
||
# attr->steps = {{ {", ".join(map(str, steps))} }};
|
||
# attr->axes = {{ {", ".join(map(str, dims))} }};
|
||
# op.op_attr.reset(attr);
|
||
# """
|
||
# inputs = [grad_output]
|
||
# outputs = [jt.zeros(self.x_shape, dtype=grad_output.dtype)]
|
||
# result = acl_cmd("StridedSliceAssignV2",
|
||
# inputs=inputs,
|
||
# outputs=outputs,
|
||
# attr_code=attr_code)[0]
|
||
# result.sync()
|
||
# if expand_dim:
|
||
# result = result.squeeze(-1)
|
||
# return result, None
|
||
# elif self.type_ == 'mask':
|
||
# return self.mask.float()
|
||
# pass
|
||
# else:
|
||
# assert False, f"grad not implemented for {self.type_}"
|
||
|
||
    from .aclops.getitem_op import GetItemACL

    def getitem_acl(x, slices, return_x=None):
        # Transform numpy int to int
        if isinstance(slices, (np.int8, np.int16, np.int32, np.int64)):
            slices = int(slices)
        if hasattr(np, 'int128') and isinstance(slices, np.int128):
            slices = int(slices)
        if hasattr(np, 'int256') and isinstance(slices, np.int256):
            slices = int(slices)

        ## If not related to `None`, directly use `GetItemACL`
        if slices is not None and (not isinstance(slices, Iterable)
                                   or all([s is not None for s in slices])):
            return GetItemACL()(x, slices, return_x)

        ## If related to `None`, filter out `None` first, then use `GetItemACL`, and finally insert `None` (new dimensions) back

        # Transform to tuple
        if isinstance(slices, int) or isinstance(slices, slice):
            slices = (slices, )
        assert isinstance(slices, tuple)

        def get_insert_positions(slices):
            result = []
            pos = 0

            not_none_cnt = len(slices) - slices.count(None)
            for s in slices:
                if isinstance(s, int):
                    continue
                elif s is None:
                    result.append(pos)
                    pos += 1
                elif s == Ellipsis:
                    pos += 1 + x.ndim - not_none_cnt
                else:
                    pos += 1

            return result

        insert_positions = get_insert_positions(slices)
        slices_without_none = tuple(s for s in slices if s is not None)
        result = GetItemACL()(x, slices_without_none, return_x)

        for i in insert_positions:
            result = result.unsqueeze(i)

        return result

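    # Example of the `None` handling above (illustrative only):
    #   x = jt.rand(4, 5)
    #   y = getitem_acl(x, (None, slice(0, 2), None))
    #   # indexing is done without the Nones, then the new axes are unsqueezed
    #   # back in, so y.shape == [1, 2, 1, 5]
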
# class SetItemACL(Function):
|
||
|
||
# def __init__(self):
|
||
# self.type_ = 'notype'
|
||
# self.value_var = True
|
||
|
||
# def stride(self, x, dim):
|
||
# stride = 1
|
||
# for i in range(dim + 1, len(x.shape)):
|
||
# stride *= x.shape[i]
|
||
# return stride
|
||
|
||
# def execute(self, x, slices, value):
|
||
# self.x_shape = x.shape
|
||
# self.input_slice = slices
|
||
# if not isinstance(value, jt.Var):
|
||
# self.value_var = False
|
||
# if isinstance(slices, jt.Var):
|
||
# if slices.dtype == "bool":
|
||
# slices_len = slices.sum().item()
|
||
# if slices_len == 0:
|
||
# return x
|
||
# if isinstance(value, int) or isinstance(value, float):
|
||
# value = jt.full((slices_len, ), value)
|
||
# assert slices.shape == x.shape, "setitem shape not match"
|
||
# assert len(value.shape) == 1, "value shape must be 1D"
|
||
# assert value.shape[
|
||
# 0] == slices_len, "value shape length must be equal to slices sum"
|
||
# self.type_ = 'mask'
|
||
# self.value_shape = value.shape
|
||
# inputs = [value, slices]
|
||
# outputs = [x.clone()]
|
||
# attr_code = f"""
|
||
# op.jt_name = "inplacemaskedscatter";
|
||
# """
|
||
# result = acl_cmd("InplaceMaskedScatter",
|
||
# inputs=inputs,
|
||
# outputs=outputs,
|
||
# attr_code=attr_code)[0]
|
||
# return result
|
||
|
||
# # assert isinstance(value,jt.Var), "value must be jt.Var"
|
||
# # self.value_shape = value.shape
|
||
# if not isinstance(slices, tuple):
|
||
# slices = (slices, )
|
||
# slices = list(slices)
|
||
# for i, s in enumerate(slices):
|
||
# if isinstance(s, int) and s < 0:
|
||
# slices[i] = x.shape[i] + s
|
||
# slices = tuple(slices)
|
||
# slices_list = list(slices)
|
||
# #check slices contains slice type
|
||
# contains_slice = False
|
||
# for s in slices:
|
||
# if not isinstance(s, jt.Var) and (isinstance(s, slice)
|
||
# or s == Ellipsis):
|
||
# contains_slice = True
|
||
# break
|
||
# if not contains_slice:
|
||
# indices = []
|
||
# value_shape = []
|
||
# slices_len = len(slices)
|
||
# boardcast_shape = caculate_shape(slices_list[0])
|
||
# for ii in range(1, len(slices)):
|
||
# dd, boardcast_shape = can_broadcast_and_shape(
|
||
# boardcast_shape, caculate_shape(slices_list[ii]))
|
||
# assert dd is True, "can not broadcast"
|
||
# value_shape = boardcast_shape
|
||
# value_shape += x.shape[slices_len:]
|
||
# if value_shape == []:
|
||
# value_shape = [1]
|
||
# if isinstance(value, int) or isinstance(value, float):
|
||
# value = jt.full(value_shape, value)
|
||
# self.value_shape = value_shape
|
||
# for ii in slices:
|
||
# indices.append(jt.Var(ii).int32())
|
||
# if isinstance(slices[0], jt.Var) or isinstance(
|
||
# slices[0], int) or isinstance(
|
||
# slices[0], list) or isinstance(slices[0], tuple):
|
||
# self.indices = indices
|
||
# self.type_ = 'index'
|
||
# attr_code = f"""
|
||
# op.jt_name = "indexputimpl";
|
||
# """
|
||
# inputs = [value] + indices
|
||
# outputs = [x.clone()]
|
||
# result = acl_cmd("IndexPutImpl",
|
||
# inputs=inputs,
|
||
# outputs=outputs,
|
||
# attr_code=attr_code)[0]
|
||
# # result.sync()
|
||
# return result
|
||
# assert "not support"
|
||
# assert contains_slice, "slice type error"
|
||
# x_dim = len(x.shape)
|
||
# slices = list(slices)
|
||
# for s in slices:
|
||
# if not isinstance(s, jt.Var) and s == Ellipsis:
|
||
# slices = slices[:slices.index(s)] + [
|
||
# slice(None, None, None)
|
||
# ] * (x_dim - len(slices) + 1) + slices[slices.index(s) +
|
||
# 1:]
|
||
# break
|
||
# slices = tuple(slices)
|
||
# self.input_slice = slices
|
||
# if len(slices) < x_dim:
|
||
# slices += (slice(None, None, None), ) * (x_dim - len(slices))
|
||
# sizes = []
|
||
# #适配华为奇怪的要求,最后一个维度的step必须是1
|
||
# expand_dim = False
|
||
# if isinstance(slices[-1], slice):
|
||
# if slices[-1].step is not None and slices[-1].step != 1:
|
||
# slices = slices + (slice(None, None, None), )
|
||
# expand_dim = True
|
||
|
||
# elif isinstance(slices[-1], int):
|
||
# #注意最后一个维度是数字
|
||
# slices = slices + (slice(None, None, None), )
|
||
# expand_dim = True
|
||
# # value = value.unsqueeze(-1)
|
||
# else:
|
||
# assert False, "not supported"
|
||
# x_shape = list(x.shape)
|
||
# if expand_dim:
|
||
# x_shape.append(1)
|
||
# x = x.unsqueeze(-1)
|
||
# value = value.unsqueeze(-1)
|
||
|
||
# squeeze_dims = []
|
||
# if isinstance(value, jt.Var):
|
||
# for dim, s in enumerate(slices):
|
||
# if isinstance(s, int):
|
||
# s = slice(s, s + 1, 1)
|
||
# squeeze_dims.append(dim)
|
||
|
||
# for dim in squeeze_dims:
|
||
# value = value.unsqueeze(dim)
|
||
|
||
# extra_data = {}
|
||
# if len(slices):
|
||
# extra_data["a"] = len(slices)
|
||
# for dim, s in enumerate(slices):
|
||
# if isinstance(s, int):
|
||
# s = slice(s, s + 1, 1)
|
||
# if isinstance(s, jt.Var):
|
||
# assert False, "jt.Var not supported"
|
||
# start, stop, step = s.indices(x_shape[dim])
|
||
# size = (stop - start - 1) // step + 1
|
||
# sizes.append(size)
|
||
# extra_data[str(dim * 3)] = start
|
||
# extra_data[str(dim * 3 + 1)] = stop
|
||
# extra_data[str(dim * 3 + 2)] = step
|
||
# else:
|
||
# extra_data["a"] = -1
|
||
# sizes = [1]
|
||
# steps = [1]
|
||
# if isinstance(value, int) or isinstance(value, float):
|
||
# value = jt.full(sizes, value)
|
||
# self.type_ = 'slicev2'
|
||
# attr_code = """
|
||
# op.jt_name = "stridedsliceassignv2";
|
||
# StrideAttr *attr = new StrideAttr();
|
||
# int slice_dim = data["a"];
|
||
|
||
# if(slice_dim == -1) {
|
||
# attr->begins = {};
|
||
# attr->ends = {};
|
||
# attr->steps = {1};
|
||
# attr->axes = {};
|
||
# } else {
|
||
# vector<long int> begins;
|
||
# vector<long int> ends;
|
||
# vector<long int> steps;
|
||
# vector<long int> dims;
|
||
# for(int dim = 0; dim < slice_dim; dim++) {
|
||
# dims.push_back(dim);
|
||
# begins.push_back(data[std::to_string(dim*3)]);
|
||
# ends.push_back(data[std::to_string(dim*3+1)]);
|
||
# steps.push_back(data[std::to_string(dim*3+2)]);
|
||
# }
|
||
# attr->begins = begins;
|
||
# attr->ends = ends;
|
||
# attr->steps = steps;
|
||
# attr->axes = dims;
|
||
# }
|
||
# op.op_attr.reset(attr);
|
||
# """
|
||
# self.value_shape = value.shape
|
||
# inputs = [value]
|
||
# outputs = [x.clone()]
|
||
# result = acl_cmd_forward("StridedSliceAssignV2",
|
||
# inputs=inputs,
|
||
# outputs=outputs,
|
||
# attr_code=attr_code,
|
||
# extra_data=extra_data)[0]
|
||
# if expand_dim:
|
||
# result = result.squeeze(-1)
|
||
# # result.sync()
|
||
# return result
|
||
|
||
# def grad(self, grad_output):
|
||
# value_grad = None
|
||
# if self.value_var:
|
||
# value_grad = grad_output[self.input_slice]
|
||
# grad_output[self.input_slice] = jt.zeros(self.value_shape)
|
||
# return grad_output, None, value_grad
|
||
|
||
    from .aclops.setitem_op import SetItemACL

    def setitem_acl(x, slices, value):
        res = SetItemACL()(x, slices, value)
        return x.assign(res)

# class BmmACL(Function):
|
||
|
||
# def __init__(self, trans_x2=False):
|
||
# super(BmmACL, self).__init__()
|
||
# self.trans_x2 = trans_x2
|
||
|
||
# def execute(self, x1, x2):
|
||
# self.input = [x1, x2]
|
||
# result = acl_cmd(
|
||
# "BatchMatMul", [x1, x2],
|
||
# output_dtypes=[x1.dtype],
|
||
# output_shapes=[
|
||
# x1.shape[:-1] +
|
||
# x2.shape[-2:-1] if self.trans_x2 else x1.shape[:-1] +
|
||
# x2.shape[-1:]
|
||
# ],
|
||
# attr_code="op.jt_name=\"bmm_trans_1\";"
|
||
# if self.trans_x2 else "op.jt_name=\"bmm\";")[0]
|
||
|
||
# return result
|
||
|
||
# def grad(self, grad_output):
|
||
# x1, x2 = self.input
|
||
# if len(x1) != len(x2):
|
||
# reshape_grad_x2 = True
|
||
# else:
|
||
# reshape_grad_x2 = False
|
||
# grad_x1 = acl_cmd("BatchMatMul", [grad_output, x2],
|
||
# output_dtypes=[x1.dtype],
|
||
# output_shapes=[
|
||
# grad_output.shape[:-1] +
|
||
# x2.shape[-2:-1] if not self.trans_x2 else
|
||
# grad_output.shape[:-1] + x1.shape[-1:]
|
||
# ],
|
||
# attr_code="op.jt_name=\"bmm_trans_1\";" if
|
||
# not self.trans_x2 else "op.jt_name=\"bmm\";")[0]
|
||
# if self.trans_x2:
|
||
# if reshape_grad_x2:
|
||
# output_shape = grad_output.shape[1:-2] + grad_output.shape[
|
||
# -1:] + x1.shape[-1:]
|
||
# grad_x2 = acl_cmd(
|
||
# "BatchMatMul", [
|
||
# grad_output.reshape(-1, grad_output.shape[-1]),
|
||
# x1.reshape(-1, x1.shape[-1])
|
||
# ],
|
||
# output_dtypes=[x2.dtype],
|
||
# output_shapes=[output_shape],
|
||
# attr_code="op.jt_name=\"bmm_trans_0\";")[0]
|
||
# else:
|
||
# output_shape = grad_output.shape[:-2] + grad_output.shape[
|
||
# -1:] + x1.shape[-1:]
|
||
# grad_x2 = acl_cmd(
|
||
# "BatchMatMul", [grad_output, x1],
|
||
# output_dtypes=[x2.dtype],
|
||
# output_shapes=[output_shape],
|
||
# attr_code="op.jt_name=\"bmm_trans_0\";")[0]
|
||
# else:
|
||
# if reshape_grad_x2:
|
||
# output_shape = x1.shape[1:-2] + x1.shape[
|
||
# -1:] + grad_output.shape[-1:]
|
||
# grad_x2 = acl_cmd(
|
||
# "BatchMatMul", [
|
||
# x1.reshape(-1, x1.shape[-1]),
|
||
# grad_output.reshape(-1, grad_output.shape[-1])
|
||
# ],
|
||
# output_dtypes=[x2.dtype],
|
||
# output_shapes=[output_shape],
|
||
# attr_code="op.jt_name=\"bmm_trans_0\";")[0]
|
||
# else:
|
||
# output_shape = x1.shape[:-2] + x1.shape[
|
||
# -1:] + grad_output.shape[-1:]
|
||
# grad_x2 = acl_cmd(
|
||
# "BatchMatMul", [x1, grad_output],
|
||
# output_dtypes=[x2.dtype],
|
||
# output_shapes=[output_shape],
|
||
# attr_code="op.jt_name=\"bmm_trans_0\";")[0]
|
||
# if len(grad_x1.shape) > len(x1.shape):
|
||
# grad_x1 = grad_x1.sum(0)
|
||
# if len(grad_x2.shape) > len(x2.shape):
|
||
# grad_x2 = grad_x2.sum(0)
|
||
# return grad_x1, grad_x2
|
||
|
||
    from .aclops.bmm_op import BmmACL

    def bmm_acl(x1, x2):
        return BmmACL()(x1, x2)

    def bmm_transpose_acl(x1, x2):
        return BmmACL(True)(x1, x2)

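    # Example (illustrative only): batch matmul shapes.
    #   a = jt.rand(8, 3, 4); b = jt.rand(8, 4, 5)
    #   bmm_acl(a, b).shape            # [8, 3, 5]
    #   c = jt.rand(8, 5, 4)
    #   bmm_transpose_acl(a, c).shape  # [8, 3, 5]; c's last two dims are transposed
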
# class MatmulACL(Function):

#     def __init__(self, trans_x2=False):
#         super(MatmulACL, self).__init__()
#         self.trans_x2 = trans_x2

#     def execute(self, x1, x2):
#         self.input = [x1, x2]
#         result = acl_cmd(
#             "MatMul", [x1, x2],
#             output_dtypes=[x1.dtype],
#             output_shapes=[
#                 x1.shape[:-1] +
#                 x2.shape[-2:-1] if self.trans_x2 else x1.shape[:-1] +
#                 x2.shape[-1:]
#             ],
#             attr_code="op.jt_name=\"matmul_trans_1\";"
#             if self.trans_x2 else "op.jt_name=\"matmul\";")[0]
#         return result

#     def grad(self, grad_output):
#         x1, x2 = self.input
#         if len(x1) != len(x2):
#             reshape_grad_x2 = True
#         else:
#             reshape_grad_x2 = False
#         grad_x1 = acl_cmd(
#             "MatMul", [grad_output, x2],
#             output_dtypes=[x1.dtype],
#             output_shapes=[
#                 grad_output.shape[:-1] + x2.shape[-2:-1]
#                 if not self.trans_x2 else grad_output.shape[:-1] +
#                 x2.shape[-1:]
#             ],
#             attr_code="op.jt_name=\"matmul_trans_1\";"
#             if not self.trans_x2 else "op.jt_name=\"matmul\";")[0]

#         if self.trans_x2:
#             if reshape_grad_x2:
#                 output_shape = grad_output.shape[1:-2] + grad_output.shape[
#                     -1:] + x1.shape[-1:]
#                 grad_x2 = acl_cmd(
#                     "MatMul", [
#                         grad_output.reshape(-1, grad_output.shape[-1]),
#                         x1.reshape(-1, x1.shape[-1])
#                     ],
#                     output_dtypes=[x2.dtype],
#                     output_shapes=[output_shape],
#                     attr_code="op.jt_name=\"matmul_trans_0\";")[0]
#             else:
#                 output_shape = grad_output.shape[:-2] + grad_output.shape[
#                     -1:] + x1.shape[-1:]
#                 grad_x2 = acl_cmd(
#                     "MatMul", [grad_output, x1],
#                     output_dtypes=[x2.dtype],
#                     output_shapes=[output_shape],
#                     attr_code="op.jt_name=\"matmul_trans_0\";")[0]
#         else:
#             if reshape_grad_x2:
#                 output_shape = x1.shape[1:-2] + x1.shape[
#                     -1:] + grad_output.shape[-1:]
#                 grad_x2 = acl_cmd(
#                     "MatMul", [
#                         x1.reshape(-1, x1.shape[-1]),
#                         grad_output.reshape(-1, grad_output.shape[-1])
#                     ],
#                     output_dtypes=[x2.dtype],
#                     output_shapes=[output_shape],
#                     attr_code="op.jt_name=\"matmul_trans_0\";")[0]
#             else:
#                 output_shape = x1.shape[:-2] + x1.shape[
#                     -1:] + grad_output.shape[-1:]
#                 grad_x2 = acl_cmd(
#                     "MatMul", [x1, grad_output],
#                     output_dtypes=[x2.dtype],
#                     output_shapes=[output_shape],
#                     attr_code="op.jt_name=\"matmul_trans_0\";")[0]
#         return grad_x1, grad_x2

from .aclops.matmul_op import MatmulACL


def matmul_acl(x1, x2):
    return MatmulACL()(x1, x2)


def matmul_transpose_acl(x1, x2):
    return MatmulACL(True)(x1, x2)

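# Illustrative usage sketch (an assumption, mirroring the BmmACL note above):
# matmul_transpose_acl(x, w) corresponds to x @ w.T, the usual linear-layer
# pattern, e.g.
#   x = jt.rand((16, 128))
#   w = jt.rand((64, 128))
#   y = matmul_transpose_acl(x, w)  # shape (16, 64)
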
# class TransPoseACL(Function):

#     def __init__(self):
#         super(TransPoseACL, self).__init__()

#     def execute(self, x, *dim):
#         self.input = x
#         if len(dim) == 1 and isinstance(dim[0], Sequence):
#             dim = dim[0]
#         elif len(dim) == 2:
#             axes = list(range(x.ndim))
#             a, b = dim
#             axes[a], axes[b] = axes[b], axes[a]
#             dim = axes

#         attr_code = f"""
#         op.jt_name = "transpose";
#         ReduceAttr *attr = new ReduceAttr();
#         attr->axes = {{ {", ".join(map(str, dim))} }};
#         op.op_attr.reset(attr);
#         """
#         # calculate output shape
#         output_shape = [x.shape[i] for i in dim]
#         output = acl_cmd("Transpose", [x],
#                          output_dtypes=[x.dtype],
#                          output_shapes=[jt.empty(output_shape).shape],
#                          attr_code=attr_code)[0]
#         self.dim = dim
#         return output

#     def grad(self, grad_output):
#         dim = list(range(grad_output.ndim))
#         for i, p in enumerate(self.dim):
#             dim[p] = i
#         output_shape = [grad_output.shape[i] for i in dim]
#         attr_code = f"""
#         op.jt_name = "transpose";
#         ReduceAttr *attr = new ReduceAttr();
#         attr->axes = {{ {", ".join(map(str, dim))} }};
#         op.op_attr.reset(attr);
#         """
#         output = acl_cmd("Transpose", [grad_output],
#                          output_dtypes=[grad_output.dtype],
#                          output_shapes=[jt.empty(output_shape).shape],
#                          attr_code=attr_code)[0]
#         return output

from .aclops.transpose_op import TransPoseACL


def transpose_acl(x, *dim):
    return TransPoseACL()(x, *dim)

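# Illustrative usage sketch (an assumption based on the commented-out
# TransPoseACL above, which accepts either a full permutation or two axes to
# swap):
#   x = jt.rand((2, 3, 4))
#   y = transpose_acl(x, (2, 0, 1))  # full permutation -> shape (4, 2, 3)
#   z = transpose_acl(x, 1, 2)       # swap two axes    -> shape (2, 4, 3)
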
# class FlashAttentionACL(Function):

#     def __init__(self,
#                  headnum,
#                  layout="BNSD",
#                  prefix=None,
#                  qstart=None,
#                  kvstart=None,
#                  scale=1.0,
#                  prob=1.0,
#                  pretokens=2147483647,
#                  nexttokens=2147483647,
#                  innerprecise=0,
#                  sparsemode=0,
#                  psetype=1):
#         self.headnum = headnum
#         self.layout = layout
#         self.scale = scale
#         self.prob = prob
#         self.pretokens = pretokens
#         self.nexttokens = nexttokens
#         self.innerprecise = innerprecise
#         self.sparsemode = sparsemode
#         self.psetype = psetype
#         self.prefix = prefix
#         self.qstart = qstart
#         self.kvstart = kvstart

#     def execute(
#         self,
#         q,
#         k,
#         v,
#         realshift=None,
#         dropMask=None,
#         paddingMask=None,
#         attenMask=None,
#     ):
#         if self.layout == 'BSH':
#             B, SQ, H = q.shape
#             SKV = k.shape[1]
#             N = self.headnum
#             D = H / N
#         elif self.layout == 'SBH':
#             SQ, B, H = q.shape
#             SKV = k.shape[0]
#             N = self.headnum
#             D = H / N
#         elif self.layout == 'BSND':
#             B, SQ, N, D = q.shape
#             SKV = k.shape[1]
#         elif self.layout == 'BNSD':
#             B, N, SQ, D = q.shape
#             SKV = k.shape[2]
#         else:
#             raise ValueError(f"got invalid input layout {self.layout}")

#         output_shape = (B, N, SQ, 8)

#         self.q = q
#         self.k = k
#         self.v = v

#         self.prefix = self.prefix if self.prefix else [0 for _ in range(B)]
#         self.qstart = self.qstart if self.qstart else [0 for _ in range(B)]
#         self.kvstart = self.kvstart if self.kvstart else [
#             0 for _ in range(B)
#         ]

#         self.hasRealshift = (not realshift == None)
#         self.hasDropmask = (not dropMask == None)
#         self.hasPaddingmask = (not paddingMask == None)
#         self.hasAttenmask = (not attenMask == None)

#         # To be decided; currently passed as nullptr
#         self.realshift = realshift if realshift else jt.zeros(
#             B, N, SQ, SKV)
#         self.dropMask = dropMask if dropMask else jt.ones(B, N, SQ, SKV)
#         self.paddingMask = paddingMask if paddingMask else jt.zeros(
#             B, N, SQ, SKV)
#         self.attenMask = attenMask if attenMask else jt.zeros(SQ, SKV)

#         attr_code = f"""
#         op.jt_name = "flashattention";
#         FlashAttentionAttr *attr = new FlashAttentionAttr();
#         attr->scale = {self.scale};
#         attr->keepProb = {self.prob};
#         attr->preToken = {self.pretokens};
#         attr->nextToken = {self.nexttokens};
#         attr->headNum = {self.headnum};
#         attr->inputLayout = "{self.layout}";
#         attr->innerPrecise = {self.innerprecise};
#         attr->sparseMode = {self.sparsemode};
#         attr->psetype = {self.psetype};
#         attr->prefix = {{ {", ".join(map(str, self.prefix))} }};
#         attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }};
#         attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }};
#         attr->hasRealshift = {"true" if self.hasRealshift else "false"};
#         attr->hasDropmask = {"true" if self.hasDropmask else "false"};
#         attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"};
#         attr->hasAttentmask = {"true" if self.hasAttenmask else "false"};
#         op.op_attr.reset(attr);
#         """

#         inputs = [
#             q, k, v, self.realshift, self.dropMask, self.paddingMask,
#             self.attenMask
#         ]

#         result = acl_cmd(
#             "FlashAttention",
#             inputs,
#             output_dtypes=["float", "float", q.dtype],
#             output_shapes=[output_shape, output_shape, q.shape],
#             attr_code=attr_code)

#         self.maxout = result[0]
#         self.sumout = result[1]
#         self.attenout = result[2]

#         return self.attenout

#     def grad(self, dy):
#         attr_code = f"""
#         op.jt_name = "flashattentionbackward";
#         FlashAttentionAttr *attr = new FlashAttentionAttr();
#         attr->scale = {self.scale};
#         attr->keepProb = {self.prob};
#         attr->preToken = {self.pretokens};
#         attr->nextToken = {self.nexttokens};
#         attr->headNum = {self.headnum};
#         attr->inputLayout = "{self.layout}";
#         attr->innerPrecise = {self.innerprecise};
#         attr->sparseMode = {self.sparsemode};
#         attr->psetype = {self.psetype};
#         attr->prefix = {{ {", ".join(map(str, self.prefix))} }};
#         attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }};
#         attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }};
#         attr->hasRealshift = {"true" if self.hasRealshift else "false"};
#         attr->hasDropmask = {"true" if self.hasDropmask else "false"};
#         attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"};
#         attr->hasAttentmask = {"true" if self.hasAttenmask else "false"};
#         op.op_attr.reset(attr);
#         """
#         inputs = [
#             self.q, self.k, self.v, dy, self.realshift, self.dropMask,
#             self.paddingMask, self.attenMask, self.maxout, self.sumout,
#             self.attenout
#         ]

#         result = acl_cmd(
#             "FlashAttentionBackward",
#             inputs,
#             output_dtypes=[self.q.dtype, self.k.dtype, self.v.dtype],
#             output_shapes=[self.q.shape, self.k.shape, self.v.shape],
#             attr_code=attr_code)
#         return result

class ReLUACL(Function):

    def __init__(self):
        super(ReLUACL, self).__init__()

    def execute(self, x):
        x = x.float32()
        self.input = x
        result = acl_cmd("ReLU", [x],
                         output_dtypes=[x.dtype],
                         output_shapes=[x.shape],
                         attr_code="op.jt_name=\"unary\";")[0]
        return result

    def grad(self, grad_output):
        mask = acl_cmd("Greater",
                       [self.input, jt.zeros(self.input.shape)],
                       output_dtypes=[self.input.dtype],
                       output_shapes=[self.input.shape],
                       attr_code="op.jt_name=\"binary\";")[0]
        grad_input = acl_cmd("Mul", [grad_output, mask],
                             output_dtypes=[grad_output.dtype],
                             output_shapes=[grad_output.shape],
                             attr_code="op.jt_name=\"binary\";")[0]
        return grad_input


class ReLU(jt.nn.Module):

    def __init__(self):
        super(ReLU, self).__init__()

    def execute(self, x):
        return ReLUACL()(x)


def relu(x):
    return ReLUACL()(x)

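# Example (illustrative): relu clamps negatives to zero, and the grad pass
# above rebuilds the 0/1 mask with the "Greater" and "Mul" ACL ops.
#   y = relu(jt.float32([-1.0, 0.5]))  # -> [0.0, 0.5]
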
# class LeakyReLUACL(Function):

#     def __init__(self):
#         super(LeakyReLUACL, self).__init__()

#     def execute(self, x, negative_slope=0.01):
#         x = x.float32()
#         self.input = x
#         attr_code = f"""
#         op.jt_name = "leakyrelu";
#         LeakyReluAttr *attr = new LeakyReluAttr();
#         attr->negativeSlope = {negative_slope};
#         op.op_attr.reset(attr);
#         """
#         result = acl_cmd("LeakyReLU", [x],
#                          output_dtypes=[x.dtype],
#                          output_shapes=[x.shape],
#                          attr_code=attr_code)[0]
#         self.negative_slope = negative_slope
#         return result

#     def grad(self, grad_output):
#         attr_code = f"""
#         op.jt_name = "leakyrelubackward";
#         LeakyReluAttr *attr = new LeakyReluAttr();
#         attr->negativeSlope = {self.negative_slope};
#         attr->selfIsResult = false;
#         op.op_attr.reset(attr);
#         """
#         grad_input = acl_cmd("LeakyReLUBackward",
#                              [grad_output, self.input],
#                              output_dtypes=[grad_output.dtype],
#                              output_shapes=[grad_output.shape],
#                              attr_code=attr_code)[0]
#         return grad_input

from .aclops.relu_op import LeakyReLUACL


class LeakyReLU(jt.nn.Module):

    def __init__(self, negative_slope=0.01):
        super(LeakyReLU, self).__init__()
        self.negative_slope = negative_slope

    def execute(self, x):
        return LeakyReLUACL()(x, self.negative_slope)


def leaky_relu(x, scale=0.01):
    return LeakyReLUACL()(x, scale)

# class DropoutACL(Function):

#     def __init__(self):
#         super(DropoutACL, self).__init__()

#     def execute(self, x, p=0.5, is_train=False):
#         self.input = x
#         num_elements = x.numel()
#         aligned_elements = (num_elements + 127) // 128 * 128
#         mask_shape = (aligned_elements // 8, )
#         attr_code = f"""
#         op.jt_name = "dropout";
#         DropoutAttr *attr = new DropoutAttr();
#         attr->p = {p};
#         attr->train = {"true" if is_train else "false"};
#         attr->seed = 0;
#         attr->offset = 0;
#         op.op_attr.reset(attr);
#         """
#         result = acl_cmd("Dropout", [x],
#                          output_dtypes=[x.dtype, "uint8"],
#                          output_shapes=[x.shape, mask_shape],
#                          attr_code=attr_code)
#         self.maskout = result[1]
#         return result[0]

#     def grad(self, grad_output):
#         attr_code = f"""
#         op.jt_name = "dropoutbackward";
#         DropoutAttr *attr = new DropoutAttr();
#         attr->scale = 1.0;
#         op.op_attr.reset(attr);
#         """
#         grad_input = acl_cmd("DropoutBackward",
#                              [grad_output, self.maskout],
#                              output_dtypes=[grad_output.dtype],
#                              output_shapes=[grad_output.shape],
#                              attr_code=attr_code)[0]
#         return grad_input

from .aclops.dropout_op import DropoutACL


class Dropout(jt.nn.Module):

    def __init__(self, p=0.5, is_train=False):
        super(Dropout, self).__init__()
        self.p = p
        self.is_train = is_train

    def execute(self, x):
        return DropoutACL()(x, self.p, self.is_train)


def dropout_acl(x, p=0.5, is_train=False):
    return DropoutACL()(x, p, is_train)

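# Illustrative usage sketch (an assumption, matching the signatures above):
#   drop = Dropout(p=0.1, is_train=True)
#   y = drop(jt.rand((4, 8)))                                # module form
#   y = dropout_acl(jt.rand((4, 8)), p=0.1, is_train=True)   # functional form
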
# class SiLUACL(Function):

#     def __init__(self):
#         super(SiLUACL, self).__init__()

#     def execute(self, x):
#         x = x.float32()
#         inputs = [x]
#         self.input = x
#         outputs = [jt.empty(x.shape, x.dtype)]
#         attr_code = f"""
#         op.jt_name = "silu";
#         """
#         result = acl_cmd("SiLU",
#                          inputs=inputs,
#                          outputs=outputs,
#                          attr_code=attr_code)[0]
#         return result

#     def grad(self, grad_output):
#         attr_code = f"""
#         op.jt_name = "silubackward";
#         """
#         inputs = [grad_output, self.input]
#         outputs = [jt.empty(grad_output.shape, grad_output.dtype)]
#         grad_input = acl_cmd("SiLUBackward",
#                              inputs=inputs,
#                              outputs=outputs,
#                              attr_code=attr_code)[0]
#         return grad_input


from .aclops.silu_op import SiLUACL


def silu_acl(x):
    return SiLUACL()(x)


class SiLU(jt.nn.Module):

    def __init__(self):
        super(SiLU, self).__init__()

    def execute(self, x):
        return SiLUACL()(x)

# class SigmoidACL(Function):

#     def __init__(self):
#         super(SigmoidACL, self).__init__()

#     def execute(self, x):
#         x = x.float32()
#         inputs = [x]
#         outputs = [jt.empty(x.shape, x.dtype)]
#         attr_code = f"""
#         op.jt_name = "sigmoid";
#         """
#         result = acl_cmd("Sigmoid",
#                          inputs=inputs,
#                          outputs=outputs,
#                          attr_code=attr_code)[0]
#         self.output = result
#         return result

#     def grad(self, grad_output):
#         attr_code = f"""
#         op.jt_name = "sigmoidbackward";
#         """
#         inputs = [grad_output, self.output]
#         outputs = [jt.empty(grad_output.shape, grad_output.dtype)]
#         grad_input = acl_cmd("SigmoidBackward",
#                              inputs=inputs,
#                              outputs=outputs,
#                              attr_code=attr_code)[0]
#         return grad_input


from .aclops.sigmoid_op import SigmoidACL


def sigmoid_acl(x):
    return SigmoidACL()(x)


class Sigmoid(jt.nn.Module):

    def __init__(self):
        super(Sigmoid, self).__init__()

    def execute(self, x):
        return SigmoidACL()(x)

class EmbeddingACL(Function):

    def __init__(self):
        super(EmbeddingACL, self).__init__()

    def execute(
        self,
        indices,
        weight,
    ):
        inputs = [weight, indices]
        self.indices = indices
        self.weight_shape = weight.shape
        output_shape = list(indices.shape) + list(weight.shape[1:])
        outputs = [jt.empty(output_shape, weight.dtype)]
        attr_code = f"""
        op.jt_name = "embedding";
        """
        result = acl_cmd("Embedding",
                         inputs=inputs,
                         outputs=outputs,
                         attr_code=attr_code)[0]
        return result

    def grad(self, grad_output):
        inputs = [grad_output, self.indices]
        outputs = [jt.empty(self.weight_shape, grad_output.dtype)]
        attr_code = f"""
        op.jt_name = "embeddingbackward";
        EmbeddingAttr *attr = new EmbeddingAttr();
        attr->numEmbeddings = {self.weight_shape[0]};
        op.op_attr.reset(attr);
        """
        grad_weight = acl_cmd("EmbeddingBackward",
                              inputs=inputs,
                              outputs=outputs,
                              attr_code=attr_code)[0]
        return None, grad_weight


class Embedding(jt.nn.Module):

    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 padding_idx=None,
                 dtype="float32"):
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weight = jt.init.gauss(
            [self.num_embeddings, self.embedding_dim], dtype)
        if padding_idx is not None:
            self.weight[padding_idx] = 0

    def execute(self, x):
        res = embedding_acl(x, self.weight)
        return res

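# Illustrative usage sketch (an assumption, matching EmbeddingACL's output
# shape of indices.shape + weight.shape[1:]):
#   emb = Embedding(num_embeddings=10, embedding_dim=4)
#   out = emb(jt.int32([1, 2, 3]))  # shape (3, 4)
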
# class SoftmaxACL(Function):

#     def __init__(self):
#         super(SoftmaxACL, self).__init__()

#     def execute(self, x, dim):
#         x = x.float32()
#         inputs = [x]
#         outputs = [jt.empty(x.shape)]
#         self.dim = dim
#         attr_code = f"""
#         op.jt_name = "softmax";
#         SoftmaxAttr *attr = new SoftmaxAttr();
#         attr->dim = {dim};
#         op.op_attr.reset(attr);
#         """
#         result = acl_cmd("Softmax",
#                          inputs=inputs,
#                          outputs=outputs,
#                          attr_code=attr_code)[0]
#         self.output = result
#         return result

#     def grad(self, grad_output):
#         attr_code = f"""
#         op.jt_name = "softmax";
#         SoftmaxAttr *attr = new SoftmaxAttr();
#         attr->dim = {self.dim};
#         op.op_attr.reset(attr);
#         """
#         inputs = [grad_output, self.output]
#         outputs = [jt.empty(grad_output.shape)]
#         grad_input = acl_cmd("SoftmaxBackward",
#                              inputs=inputs,
#                              outputs=outputs,
#                              attr_code=attr_code)[0]
#         return grad_input

from .aclops.softmax_op import SoftmaxACL


class Softmax(jt.nn.Module):

    def __init__(self):
        super(Softmax, self).__init__()

    def execute(self, x, dim):
        return SoftmaxACL()(x, dim)


def softmax_acl(x, dim):
    return SoftmaxACL()(x, dim)

# class RopeACL(Function):

#     def __init__(self):
#         super(RopeACL, self).__init__()

#     def execute(self, xq, xk, freqs_cis, freq_cos, freq_sin):
#         attr_code = f"""
#         op.jt_name = "RotaryPosEmb";
#         """
#         if freqs_cis is not None:
#             freq_cos = freqs_cis[..., 0]
#             freq_sin = freqs_cis[..., 1]
#         else:
#             assert freq_cos is not None and freq_sin is not None
#         inputs = [xq, xk, freq_cos, freq_sin]
#         results = acl_cmd("RotaryPosEmb",
#                           inputs,
#                           output_dtypes=[
#                               xq.dtype,
#                           ],
#                           output_shapes=[
#                               xq.shape,
#                           ],
#                           attr_code=attr_code)
#         results[0].sync()
#         return inputs[0], inputs[1]

#     def grad(self, grad_output):
#         return grad_output

from .aclops.rope_op import RopeACL


def rope_acl(xq, xk, freqs_cis=None, freq_sin=None, freq_cos=None):
    return RopeACL()(xq, xk, freqs_cis, freq_sin, freq_cos)

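# Illustrative usage sketch (an assumption): rope_acl applies rotary position
# embedding to the query/key tensors and returns the rotated pair, either from
# a packed freqs_cis tensor or from separate sin/cos tables, e.g.
#   xq_out, xk_out = rope_acl(xq, xk, freqs_cis)
#   xq_out, xk_out = rope_acl(xq, xk, freq_sin=sin_tab, freq_cos=cos_tab)
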
class BatchNormACL(Function):

    def __init__(self,
                 num_features,
                 eps=1e-05,
                 momentum=0.1,
                 affine=True,
                 is_train=True,
                 sync=True):
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.is_train = is_train
        self.sync = sync
        self.weight = jt.init.constant(
            (num_features, ), "float32", 1.0) if affine else 1.0
        self.bias = jt.init.constant(
            (num_features, ), "float32", 0.0) if affine else 0.0
        self.running_mean = jt.init.constant((num_features, ), "float32",
                                             0.0).stop_grad()
        self.running_var = jt.init.constant((num_features, ), "float32",
                                            1.0).stop_grad()

    def execute(self, x):
        # assert self.num_features == x.shape[-1]
        self.input = x.float32()
        inputs = [
            self.input, self.weight, self.bias, self.running_mean,
            self.running_var
        ]
        outputs = [
            jt.empty(x.shape),
            jt.empty(self.num_features),
            jt.empty(self.num_features)
        ]
        attr_code = f"""
        op.jt_name = "batchnorm";
        BatchNormAttr *attr = new BatchNormAttr();
        attr->is_train = {"true" if self.is_train else "false"};
        attr->momentum = {self.momentum};
        attr->eps = {self.eps};
        op.op_attr.reset(attr);
        """
        result = acl_cmd("BatchNorm",
                         inputs=inputs,
                         outputs=outputs,
                         attr_code=attr_code)
        self.output = result[0]
        self.saveMean = result[1]
        self.saveInvstd = result[2]
        return self.output

    def grad(self, grad_output):
        attr_code = f"""
        op.jt_name = "batchnorm";
        BatchNormAttr *attr = new BatchNormAttr();
        attr->is_train = {"true" if self.is_train else "false"};
        attr->momentum = {self.momentum};
        attr->eps = {self.eps};
        op.op_attr.reset(attr);
        """
        inputs = [
            grad_output, self.input, self.weight, self.running_mean,
            self.running_var, self.saveMean, self.saveInvstd
        ]
        outputs = [
            jt.empty(self.input.shape),
            jt.empty(self.num_features),
            jt.empty(self.num_features)
        ]
        grad_input = acl_cmd("SoftmaxBackward",
                             inputs=inputs,
                             outputs=outputs,
                             attr_code=attr_code)[0]
        return grad_input

# class LayerNormACL(Function):

#     def __init__(self,
#                  normalized_shape,
#                  eps: float = 1e-5,
#                  elementwise_affine: bool = True):
#         if isinstance(normalized_shape, int):
#             normalized_shape = (normalized_shape, )
#         self.normalized_shape = tuple(normalized_shape)
#         self.eps = eps
#         self.elementwise_affine = elementwise_affine
#         self.weight = jt.init.constant(normalized_shape, "float32",
#                                        1.0) if elementwise_affine else 1.0
#         self.bias = jt.init.constant(normalized_shape, "float32",
#                                      0.0) if elementwise_affine else 0.0

#     def execute(self, x):
#         self.input = x.float32()
#         inputs = [self.input, self.weight, self.bias]
#         outputs = [jt.empty(x.shape), jt.empty(x.shape), jt.empty(x.shape)]
#         attr_code = f"""
#         op.jt_name = "layernorm";
#         LayerNormAttr *attr = new LayerNormAttr();
#         attr->eps = {self.eps};
#         attr->normalizedShape = {{{', '.join(map(str, (list(self.normalized_shape))))}}};
#         attr->size = {x.shape[-1]};
#         op.op_attr.reset(attr);
#         """
#         result = acl_cmd("LayerNorm",
#                          inputs=inputs,
#                          outputs=outputs,
#                          attr_code=attr_code)
#         self.output = result[0]
#         self.meanout = result[1]
#         self.rstdout = result[2]
#         return self.output

#     def grad(self, grad_output):
#         attr_code = f"""
#         op.jt_name = "batchnorm";
#         BatchNormAttr *attr = new BatchNormAttr();
#         attr->is_train = {"true" if self.is_train else "false"};
#         attr->momentum = {self.momentum};
#         attr->eps = {self.eps};
#         op.op_attr.reset(attr);
#         """
#         inputs = [grad_output, self.input, self.weight, self.running_mean,
#                   self.running_var, self.saveMean, self.saveInvstd]
#         outputs = [jt.empty(self.input.shape), jt.empty(self.num_features),
#                    jt.empty(self.num_features)]
#         grad_input = acl_cmd("SoftmaxBackward",
#                              inputs=inputs,
#                              outputs=outputs,
#                              attr_code=attr_code)[0]
#         return grad_input

# class StackACL(Function):

#     def __init__(self):
#         super(StackACL, self).__init__()

#     def execute(self, input_tensors, dim):
#         if type(input_tensors) is tuple:
#             input_tensors = list(input_tensors)
#         assert type(input_tensors) is list
#         assert -1 * len(input_tensors) - 1 <= dim and dim <= len(
#             input_tensors)
#         for i in range(len(input_tensors)):
#             if input_tensors[i].dtype != input_tensors[0].dtype:
#                 raise ValueError(
#                     "All input tensors must have the same dtype")
#             if input_tensors[i].shape != input_tensors[0].shape:
#                 raise ValueError(
#                     "All input tensors must have the same shape")
#         self.input = input_tensors
#         input_shape = list(input_tensors[0].shape)
#         output_shape = input_shape[:dim] + [len(input_tensors)
#                                             ] + input_shape[dim:]
#         attr_code = f"""
#         op.jt_name = "stack";
#         ConcatAttr *attr = new ConcatAttr();
#         attr->tensorNum = {len(input_tensors)};
#         attr->dim = {dim};
#         op.op_attr.reset(attr);
#         """
#         self.attr_code = attr_code
#         result = acl_cmd("Stack",
#                          input_tensors,
#                          output_dtypes=[input_tensors[0].dtype],
#                          output_shapes=[output_shape],
#                          attr_code=self.attr_code)[0]
#         return result

#     def grad(self, grad_output):
#         grad_inputs = self.split_grad(grad_output, self.input, self.dim)
#         return grad_inputs

#     def split_grad(self, grad_output, input_tensors, axis):
#         offset = []
#         shapeVec = []
#         dtypeVec = []
#         for tensor in input_tensors:
#             offset.append(tensor.shape[axis])
#             dtypeVec.append(tensor.dtype)
#             shapeVec.append(tensor.shape)

#         attr_code = f"""
#         op.jt_name = "splitwithsize";
#         auto *attr = new SplitWithSizeAttr();
#         attr->splitSize = {{ {", ".join(map(str, offset))} }};
#         attr->dim = {axis};
#         op.op_attr.reset(attr);
#         """

#         result = acl_cmd("SplitWithSize", [grad_output],
#                          output_dtypes=dtypeVec,
#                          output_shapes=shapeVec,
#                          attr_code=attr_code)
#         return result

from .aclops.stack_op import StackACL


def stack_acl(x, dim=0):
    return StackACL()(x, dim)

# class NanToNumACL(Function):

#     def __init__(self):
#         super(NanToNumACL, self).__init__()

#     def execute(self, input, nan_or_inf):
#         attr_code = f"""
#         op.jt_name = "NanToNum";
#         NanToNumAttr *attr = new NanToNumAttr();
#         attr->nan = {nan_or_inf};
#         attr->posinf = {-nan_or_inf};
#         attr->neginf = {-nan_or_inf};
#         op.op_attr.reset(attr);
#         """
#         self.attr_code = attr_code
#         result = acl_cmd("NanToNum", [input],
#                          output_dtypes=[input[0].dtype],
#                          output_shapes=[input.shape],
#                          attr_code=self.attr_code)[0]
#         return result

from .aclops.nantonum_op import NanToNumACL


def isnan_acl(x):
    tonum = NanToNumACL()(x, -1.0)
    return jt.not_equal(x, tonum).logical_and(
        jt.not_equal(tonum, jt.ones_like(x)))


def isinf_acl(x):
    tonum = NanToNumACL()(x, 1.0)
    return jt.not_equal(x, tonum).logical_and(
        jt.not_equal(tonum, jt.ones_like(x)))

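# Note on the two helpers above (based on the commented-out NanToNumACL):
# NanToNumACL()(x, v) maps NaN to v and +/-Inf to -v.  Comparing x with the
# mapped tensor flags every element that was replaced, and the extra
# "tonum != 1" test separates the two cases: with v = -1.0 only NaN positions
# survive (isnan_acl), with v = 1.0 only Inf positions survive (isinf_acl).
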
def warp(origin_func, new_func, name=None):

    if isinstance(origin_func, type):

        class WrappedClass(origin_func, new_func):

            def __init__(self, *args, **kwargs):
                if jt.flags.use_acl:
                    new_func.__init__(self, *args, **kwargs)
                else:
                    origin_func.__init__(self, *args, **kwargs)

            def execute(self, *args, **kwargs):
                if jt.flags.use_acl:
                    return new_func.execute(self, *args, **kwargs)
                elif name == 'setitem':
                    return args[0].assign(origin_func(*args, **kwargs))
                else:
                    return origin_func.execute(self, *args, **kwargs)

        return WrappedClass

    else:

        def warpper(*args, **kwargs):
            if jt.flags.use_acl:
                return new_func(*args, **kwargs)
            elif name == 'setitem':
                return args[0].assign(origin_func(*args, **kwargs))
            else:
                return origin_func(*args, **kwargs)

        return warpper

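# Note: warp() is the dispatch glue used below.  For classes it builds a
# subclass whose __init__/execute pick the ACL implementation when
# jt.flags.use_acl is set and fall back to the original Jittor one otherwise;
# for plain functions it returns a wrapper that makes the same check per call.
# For example, after jt.nn.relu = warp(jt.nn.relu, relu), calling
# jt.nn.relu(x) runs ReLUACL only on an ACL device.
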
jt.triu = warp(jt.triu, triu_acl)
jt.triu_ = warp(jt.triu, triu_acl)
jt.Var.triu = jt.triu
jt.Var.triu_ = lambda x, diagonal=0: x.assign(x.triu(diagonal))
jt.nn.conv2d = warp(jt.nn.conv2d, conv_acl)
jt.nn.Conv2d = warp(jt.nn.Conv2d, Conv2D)
jt.nn.Conv = warp(jt.nn.Conv, Conv2D)

from .aclops.pool_op import PoolACL
jt.nn.Pool = warp(jt.nn.Pool, PoolACL)

jt.flip = warp(jt.flip, flip_acl)
jt.Var.flip = lambda x, dim_vector=0: jt.flip(x, dim_vector)
jt.concat = warp(jt.concat, concat)
jt.stack = warp(jt.stack, stack_acl)

jt.gather = warp(jt.gather, gather_acl)
jt.any = warp(jt.any, any_acl)
jt.Var.any = jt.any

jt.cumsum = warp(jt.cumsum, cumsum_acl)
jt.cub_cumsum = jt.cumsum
jt.Var.cumsum = jt.cumsum
jt.Var.cub_cumsum = jt.cumsum

jt.cumprod = warp(jt.cumprod, cumprod_acl)
jt.index = warp(jt.index, index_acl)
jt.Var.index = jt.index

jt.scatter = warp(jt.scatter, scatter_acl)
jt.Var.scatter = lambda x, dim, index, src, reduce="void": jt.scatter(
    x, dim, index, src, reduce)

jt.where = warp(jt.where, where_acl)
jt.nonzero = warp(jt.nonzero, nonzero_acl)
jt.misc.nonzero = warp(jt.misc.nonzero, nonzero_acl)
jt.Var.nonzero = jt.misc.nonzero
jt.floor_int = warp(jt.floor_int, floor_int_acl)
jt.Var.floor_int = lambda x: jt.floor_int(x)

jt.getitem = warp(jt.contrib.getitem, getitem_acl)
fake_getitem = jt.Var.getitem
jt.Var.getitem = lambda x, slices, return_x=None: warp(
    fake_getitem, getitem_acl)(x, slices)
jt.Var.slice_var = lambda x, slices, return_x=None: warp(
    fake_getitem, getitem_acl)(x, slices)
jt.Var.__getitem__ = lambda x, slices, return_x=None: warp(
    fake_getitem, getitem_acl)(x, slices)

jt.setitem = warp(jt.contrib.setitem, setitem_acl)
fake_setitem = jt.Var.setitem
jt.Var.setitem = lambda x, slices, value: warp(
    fake_setitem, setitem_acl, name='setitem')(x, slices, value)
jt.Var.__setitem__ = lambda x, slices, value: warp(
    fake_setitem, setitem_acl, name='setitem')(x, slices, value)

jt.nn.bmm = warp(jt.nn.bmm, bmm_acl)
jt.bmm = warp(jt.bmm, bmm_acl)
jt.nn.matmul = warp(jt.matmul, matmul_acl)
jt.matmul = warp(jt.matmul, matmul_acl)
jt.nn.matmul_transpose = warp(jt.nn.matmul_transpose, matmul_transpose_acl)
jt.nn.bmm_transpose = warp(jt.nn.bmm_transpose, bmm_transpose_acl)
jt.bmm_transpose = warp(jt.bmm_transpose, bmm_transpose_acl)

jt.transpose = warp(jt.transpose, transpose_acl)
fake_transpose = jt.transpose
jt.Var.transpose = lambda x, *dim: warp(fake_transpose, transpose_acl)(x, *dim)
# jt.Var.permute = lambda x: warp(fake_transpose, transpose_acl)(x)
# jt.Var.t = lambda x: warp(fake_transpose, transpose_acl)(x)

jt.nn.relu = warp(jt.nn.relu, relu)
jt.nn.ReLU = warp(jt.nn.ReLU, ReLU)

jt.nn.leaky_relu = warp(jt.nn.leaky_relu, leaky_relu)
jt.nn.LeakyReLU = warp(jt.nn.LeakyReLU, LeakyReLU)

# jt.nn.silu = warp(jt.nn.silu, silu_acl)
# jt.nn.SiLU = warp(jt.nn.SiLU, SiLU)

jt.sigmoid = warp(jt.sigmoid, sigmoid_acl)
jt.nn.Sigmoid = warp(jt.nn.Sigmoid, Sigmoid)

# def embedding_acl(indices, weight):
#     return EmbeddingACL()(indices, weight)

# jt.nn.embedding = warp(jt.nn.embedding, embedding_acl)
# jt.nn.Embedding = warp(jt.nn.Embedding, Embedding)
jt.nn.dropout = warp(jt.nn.dropout, dropout_acl)
jt.nn.Dropout = warp(jt.nn.Dropout, Dropout)

jt.nn.softmax = warp(jt.nn.softmax, softmax_acl)

# jt.nn.BatchNorm = warp(jt.nn.BatchNorm, BatchNormACL)
# jt.nn.LayerNorm = warp(jt.nn.LayerNorm, LayerNormACL)
from .aclops.flashattention_op import FlashAttentionACL
jt.nn.FlashAttention = warp(jt.nn.FlashAttention, FlashAttentionACL)
jt.isnan = warp(jt.isnan, isnan_acl)
jt.isinf = warp(jt.isinf, isinf_acl)
jt.Var.isnan = jt.isnan
jt.Var.isinf = jt.isinf

jt.nn.rotary_emb = rope_acl
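# Note: unlike the warp()-wrapped bindings above, jt.nn.rotary_emb is bound
# directly to rope_acl, so it is only expected to work when the ACL backend
# is active.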