mirror of https://github.com/Jittor/Jittor
fp16 support, opbytype interface
This commit is contained in:
parent
81b847e6f0
commit
bcb57086c3
|
@ -1235,6 +1235,7 @@ files4 = [ f[len(jittor_path)+1:] for f in files4 ]
|
||||||
# files4 = run_cmd('find -L src | grep '+grep_args, jittor_path).splitlines()
|
# files4 = run_cmd('find -L src | grep '+grep_args, jittor_path).splitlines()
|
||||||
at_beginning = [
|
at_beginning = [
|
||||||
"src/ops/op_utils.cc",
|
"src/ops/op_utils.cc",
|
||||||
|
"src/ops/op_register.cc",
|
||||||
"src/event_queue.cc",
|
"src/event_queue.cc",
|
||||||
"src/mem/allocator/sfrl_allocator.cc",
|
"src/mem/allocator/sfrl_allocator.cc",
|
||||||
"src/mem/allocator.cc",
|
"src/mem/allocator.cc",
|
||||||
|
|
|
@ -60,14 +60,14 @@ class Pool(Module):
|
||||||
'''
|
'''
|
||||||
if not self.return_indices:
|
if not self.return_indices:
|
||||||
forward_body += f'''
|
forward_body += f'''
|
||||||
@out(i0, i1, i2, i3) = init_{self.op}(out_type);
|
@out(i0, i1, i2, i3) = @expand_op(init_{self.op}, @out_type);
|
||||||
for (int p = k2; p < k2_; ++p)
|
for (int p = k2; p < k2_; ++p)
|
||||||
for (int q = k3; q < k3_; ++q)
|
for (int q = k3; q < k3_; ++q)
|
||||||
@out(i0, i1, i2, i3) = {self.op}(out_type, @out(i0, i1, i2, i3), @in0(i0, i1, p, q));
|
@out(i0, i1, i2, i3) = @expand_op({self.op}, @out_type, @out(i0, i1, i2, i3), @out_type, @in0(i0, i1, p, q), @in0_type);
|
||||||
'''
|
'''
|
||||||
else:
|
else:
|
||||||
forward_body += f'''
|
forward_body += f'''
|
||||||
auto out_value = init_{self.op}(out_type);
|
auto out_value = @expand_op(init_{self.op}, @out_type);
|
||||||
int out_index = -1;
|
int out_index = -1;
|
||||||
for (int p = k2; p < k2_; ++p)
|
for (int p = k2; p < k2_; ++p)
|
||||||
for (int q = k3; q < k3_; ++q)
|
for (int q = k3; q < k3_; ++q)
|
||||||
|
@ -105,7 +105,6 @@ class Pool(Module):
|
||||||
return_dtypes = x.dtype
|
return_dtypes = x.dtype
|
||||||
out = jt.code(return_shapes, return_dtypes, [x],
|
out = jt.code(return_shapes, return_dtypes, [x],
|
||||||
cuda_header="""
|
cuda_header="""
|
||||||
#include <ops/binary_op_defs.h>
|
|
||||||
#include <misc/cuda_limits.h>
|
#include <misc/cuda_limits.h>
|
||||||
""",
|
""",
|
||||||
cuda_src=f'''
|
cuda_src=f'''
|
||||||
|
@ -153,7 +152,7 @@ class Pool(Module):
|
||||||
dim3 s2_(tx, ty);
|
dim3 s2_(tx, ty);
|
||||||
kernel3<<<s1_, s2_>>>(@ARGS);
|
kernel3<<<s1_, s2_>>>(@ARGS);
|
||||||
'''],
|
'''],
|
||||||
cpu_header='#include <ops/binary_op_defs.h>',
|
cpu_header='',
|
||||||
cpu_src=f'''
|
cpu_src=f'''
|
||||||
using namespace std;
|
using namespace std;
|
||||||
for (int i0=0; i0<out_shape0; i0++)
|
for (int i0=0; i0<out_shape0; i0++)
|
||||||
|
@ -242,15 +241,15 @@ class Pool3d(Module):
|
||||||
'''
|
'''
|
||||||
if not self.return_indices:
|
if not self.return_indices:
|
||||||
forward_body += f'''
|
forward_body += f'''
|
||||||
@out(i0, i1, i2, i3, i4) = init_{self.op}(out_type);
|
@out(i0, i1, i2, i3, i4) = @expand_op(init_{self.op}, @out_type);
|
||||||
for (int p = k2; p < k2_; ++p)
|
for (int p = k2; p < k2_; ++p)
|
||||||
for (int q = k3; q < k3_; ++q)
|
for (int q = k3; q < k3_; ++q)
|
||||||
for (int r = k4; r < k4_; ++r)
|
for (int r = k4; r < k4_; ++r)
|
||||||
@out(i0, i1, i2, i3, i4) = {self.op}(out_type, @out(i0, i1, i2, i3, i4), @in0(i0, i1, p, q, r));
|
@out(i0, i1, i2, i3, i4) = @expand_op({self.op}, @out_type, @out(i0, i1, i2, i3, i4), @out_type, @in0(i0, i1, p, q, r), @in0_type);
|
||||||
'''
|
'''
|
||||||
else:
|
else:
|
||||||
forward_body += f'''
|
forward_body += f'''
|
||||||
auto out_value = init_{self.op}(out_type);
|
auto out_value = @expand_op(init_{self.op}, @out_type);
|
||||||
int out_index = -1;
|
int out_index = -1;
|
||||||
for (int p = k2; p < k2_; ++p)
|
for (int p = k2; p < k2_; ++p)
|
||||||
for (int q = k3; q < k3_; ++q)
|
for (int q = k3; q < k3_; ++q)
|
||||||
|
@ -293,7 +292,6 @@ class Pool3d(Module):
|
||||||
return_dtypes = x.dtype
|
return_dtypes = x.dtype
|
||||||
out = jt.code(return_shapes, return_dtypes, [x],
|
out = jt.code(return_shapes, return_dtypes, [x],
|
||||||
cuda_header="""
|
cuda_header="""
|
||||||
#include <ops/binary_op_defs.h>
|
|
||||||
#include <misc/cuda_limits.h>
|
#include <misc/cuda_limits.h>
|
||||||
""",
|
""",
|
||||||
cuda_src=f'''
|
cuda_src=f'''
|
||||||
|
@ -349,7 +347,7 @@ class Pool3d(Module):
|
||||||
dim3 s2(tx, ty, tz);
|
dim3 s2(tx, ty, tz);
|
||||||
kernel3<<<s1, s2>>>(@ARGS);
|
kernel3<<<s1, s2>>>(@ARGS);
|
||||||
'''],
|
'''],
|
||||||
cpu_header='#include <ops/binary_op_defs.h>',
|
cpu_header='',
|
||||||
cpu_src=f'''
|
cpu_src=f'''
|
||||||
using namespace std;
|
using namespace std;
|
||||||
for (int i0=0; i0<out_shape0; i0++)
|
for (int i0=0; i0<out_shape0; i0++)
|
||||||
|
|
|
@ -223,6 +223,16 @@ void load_macros(const string& src, unordered_map<string,string>& macros) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string expand_op_search(const vector<string>& args) {
|
||||||
|
for (auto op_type : op_types) {
|
||||||
|
string ret = op_type->expand_op(args);
|
||||||
|
if (ret.size())
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
LOGf << "No expand op pattern found for args:" << args;
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
void expand_macro(const string& macro, const vector<string>& args, string& new_src) {
|
void expand_macro(const string& macro, const vector<string>& args, string& new_src) {
|
||||||
LOGvvvv << "expand_macro" << macro << "args:" << args;
|
LOGvvvv << "expand_macro" << macro << "args:" << args;
|
||||||
if (macro.size() == 0 || macro[0] != '<') {
|
if (macro.size() == 0 || macro[0] != '<') {
|
||||||
|
@ -434,6 +444,7 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
||||||
vector<string> args;
|
vector<string> args;
|
||||||
size_t l = k+1;
|
size_t l = k+1;
|
||||||
if (expr == "for" || expr == "if" || expr == "expand_macro" ||
|
if (expr == "for" || expr == "if" || expr == "expand_macro" ||
|
||||||
|
expr == "expand_op" ||
|
||||||
expr == "is_def" || expr == "python" ||
|
expr == "is_def" || expr == "python" ||
|
||||||
(k<src.size() && src[k]=='(')) {
|
(k<src.size() && src[k]=='(')) {
|
||||||
ASSERT(src[k] == '(');
|
ASSERT(src[k] == '(');
|
||||||
|
@ -555,6 +566,18 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
||||||
i = l-1;
|
i = l-1;
|
||||||
continue;
|
continue;
|
||||||
} else
|
} else
|
||||||
|
if (expr == "expand_op") {
|
||||||
|
// syntax: @expand_op(args)
|
||||||
|
for (auto& arg : args) {
|
||||||
|
uint p=0;
|
||||||
|
while (p<arg.size() && arg[p] == ' ') p++;
|
||||||
|
arg = precompile(defs, arg.substr(p), macros);
|
||||||
|
}
|
||||||
|
string ns = expand_op_search(args);
|
||||||
|
new_src += precompile(defs, ns, macros);
|
||||||
|
i = l-1;
|
||||||
|
continue;
|
||||||
|
} else
|
||||||
if (expr == "define") {
|
if (expr == "define") {
|
||||||
// syntax: @define(macro, value)
|
// syntax: @define(macro, value)
|
||||||
// ij k l
|
// ij k l
|
||||||
|
@ -846,6 +869,9 @@ string OpCompiler::__get_fused_src(
|
||||||
};
|
};
|
||||||
auto not_change = [&](const string& s) -> bool {
|
auto not_change = [&](const string& s) -> bool {
|
||||||
if (unchanged.count(s)) return true;
|
if (unchanged.count(s)) return true;
|
||||||
|
for (auto op_type : op_types)
|
||||||
|
if (op_type->types.count(s))
|
||||||
|
return true;
|
||||||
return (s.find("::") != string::npos) || (s.find("LOG") != string::npos);
|
return (s.find("::") != string::npos) || (s.find("LOG") != string::npos);
|
||||||
};
|
};
|
||||||
// regex find XxxXxxOp::jit_run
|
// regex find XxxXxxOp::jit_run
|
||||||
|
|
|
@ -8,7 +8,6 @@
|
||||||
#include "var.h"
|
#include "var.h"
|
||||||
#include "ops/binary_op.h"
|
#include "ops/binary_op.h"
|
||||||
#include "ops/broadcast_to_op.h"
|
#include "ops/broadcast_to_op.h"
|
||||||
#include "ops/binary_op_defs.h"
|
|
||||||
#include "ops/op_register.h"
|
#include "ops/op_register.h"
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
@ -554,7 +553,7 @@ void BinaryOp::jit_run() {
|
||||||
auto* __restrict__ zp = z->ptr<Tz>();
|
auto* __restrict__ zp = z->ptr<Tz>();
|
||||||
index_t num = z->num;
|
index_t num = z->num;
|
||||||
for (index_t i=0; i<num; i++)
|
for (index_t i=0; i<num; i++)
|
||||||
zp[i] = @expand_macro(@OP, Tz, xp[i], yp[i]);
|
zp[i] = @expand_op(@OP, @Tz, xp[i], @Tx, yp[i], @Ty);
|
||||||
}
|
}
|
||||||
#endif // JIT
|
#endif // JIT
|
||||||
|
|
||||||
|
|
|
@ -1,61 +0,0 @@
|
||||||
// ***************************************************************
|
|
||||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
|
||||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
|
||||||
// This file is subject to the terms and conditions defined in
|
|
||||||
// file 'LICENSE.txt', which is part of this source code package.
|
|
||||||
// ***************************************************************
|
|
||||||
#pragma once
|
|
||||||
#include "common.h"
|
|
||||||
|
|
||||||
namespace jittor {
|
|
||||||
|
|
||||||
#ifdef JIT_cuda
|
|
||||||
#define pow(T,a,b) ::pow(a,b)
|
|
||||||
#define maximum(T,a,b) ::max(T(a), T(b))
|
|
||||||
#define minimum(T,a,b) ::min(T(a), T(b))
|
|
||||||
#define mod(T,a,b) @if(@strcmp(@T,float32)==0,(a-::floorf((a)/(b))*(b)),@if(@strcmp(@Tx,float64)==0,(a-::floor((a)/(b))*(b)),(a%b)))
|
|
||||||
#else // JIT_cpu
|
|
||||||
#define pow(T,a,b) std::pow(a,b)
|
|
||||||
#define maximum(T,a,b) std::max(T(a), T(b))
|
|
||||||
#define minimum(T,a,b) std::min(T(a), T(b))
|
|
||||||
#define mod(T,a,b) @if(@strcmp(@T,float32)==0,(a-std::floor((a)/(b))*(b)),@if(@strcmp(@Tx,float64)==0,(a-std::floor((a)/(b))*(b)),(a%b)))
|
|
||||||
#endif
|
|
||||||
#define add(T,a,b) ((a)+(b))
|
|
||||||
#define subtract(T,a,b) ((a)-(b))
|
|
||||||
#define multiply(T,a,b) ((a)*(b))
|
|
||||||
#define divide(T,a,b) (T((T(a))/(T(b))))
|
|
||||||
#define floor_divide(T,a,b) (T((T(a))/(T(b))))
|
|
||||||
#define less(T,a,b) ((a)<(b))
|
|
||||||
#define less_equal(T,a,b) ((a)<=(b))
|
|
||||||
#define greater(T,a,b) ((a)>(b))
|
|
||||||
#define greater_equal(T,a,b) ((a)>=(b))
|
|
||||||
#define equal(T,a,b) ((a)==(b))
|
|
||||||
#define not_equal(T,a,b) ((a)!=(b))
|
|
||||||
#define left_shift(T,a,b) ((a)<<(b))
|
|
||||||
#define right_shift(T,a,b) ((a)>>(b))
|
|
||||||
#define logical_and(T,a,b) ((a)&&(b))
|
|
||||||
#define logical_or(T,a,b) ((a)||(b))
|
|
||||||
#define logical_xor(T,a,b) ((bool(a))!=(bool(b)))
|
|
||||||
#define bitwise_and(T,a,b) ((a)&(b))
|
|
||||||
#define bitwise_or(T,a,b) ((a)|(b))
|
|
||||||
#define bitwise_xor(T,a,b) ((a)^(b))
|
|
||||||
#define mean(T,a,b) ((a)+T(b)*(T(rcount)))
|
|
||||||
|
|
||||||
#ifdef JIT_cuda
|
|
||||||
#define init_maximum(T) ::numeric_min<T>()
|
|
||||||
#define init_minimum(T) ::numeric_max<T>()
|
|
||||||
#else
|
|
||||||
#define init_maximum(T) std::numeric_limits<T>::lowest()
|
|
||||||
#define init_minimum(T) std::numeric_limits<T>::max()
|
|
||||||
#endif
|
|
||||||
#define init_add(T) T(0)
|
|
||||||
#define init_multiply(T) T(1)
|
|
||||||
#define init_logical_and(T) true
|
|
||||||
#define init_logical_or(T) false
|
|
||||||
#define init_logical_xor(T) false
|
|
||||||
#define init_bitwise_and(T) T(-1)
|
|
||||||
#define init_bitwise_or(T) T(0)
|
|
||||||
#define init_bitwise_xor(T) T(0)
|
|
||||||
#define init_mean(T) T(0)
|
|
||||||
|
|
||||||
} // jittor
|
|
|
@ -0,0 +1,159 @@
|
||||||
|
// ***************************************************************
|
||||||
|
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||||
|
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||||
|
// This file is subject to the terms and conditions defined in
|
||||||
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
// ***************************************************************
|
||||||
|
#include "common.h"
|
||||||
|
#include "utils/str_utils.h"
|
||||||
|
#include "ops/op_register.h"
|
||||||
|
|
||||||
|
namespace jittor {
|
||||||
|
|
||||||
|
extern int use_cuda;
|
||||||
|
|
||||||
|
struct CommonOpType : OpByType {
|
||||||
|
CommonOpType() {
|
||||||
|
types = {
|
||||||
|
"bool",
|
||||||
|
"int8",
|
||||||
|
"int16",
|
||||||
|
"int32",
|
||||||
|
"int64",
|
||||||
|
"uint8",
|
||||||
|
"uint16",
|
||||||
|
"uint32",
|
||||||
|
"uint64",
|
||||||
|
"float32",
|
||||||
|
"float64",
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
string expand_op(const vector<string>& args) {
|
||||||
|
for (int i=1; i<args.size(); i+=2) {
|
||||||
|
if (!types.count(args[i]))
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
static unordered_map<string,string> cuda_map = {
|
||||||
|
{"logical_not", "(!($2))"},
|
||||||
|
{"bitwise_not", "(~($2))"},
|
||||||
|
{"negative", "(-($2))"},
|
||||||
|
{"abs", "::abs($2)"},
|
||||||
|
{"log", "::logf(($1)($2))"},
|
||||||
|
{"exp", "::expf(($1)($2))"},
|
||||||
|
{"sqrt", "::sqrtf(($1)($2))"},
|
||||||
|
{"round", "(($1) ::roundf(($2)))"},
|
||||||
|
{"floor", "(($1) ::floorf(($2)))"},
|
||||||
|
{"ceil", "(($1) ::ceilf(($2)))"},
|
||||||
|
{"round_int", "(($1) ::roundf(($2)))"},
|
||||||
|
{"floor_int", "(($1) ::floorf(($2)))"},
|
||||||
|
{"ceil_int", "(($1) ::ceilf(($2)))"},
|
||||||
|
{"sin", "(($1) ::sinf(($2)))"},
|
||||||
|
{"asin", "(($1) ::asinf(($2)))"},
|
||||||
|
{"sinh", "(($1) ::sinhf(($2)))"},
|
||||||
|
{"asinh", "(($1) ::asinhf(($2)))"},
|
||||||
|
{"cos", "(($1) ::cosf(($2)))"},
|
||||||
|
{"acos", "(($1) ::acosf(($2)))"},
|
||||||
|
{"cosh", "(($1) ::coshf(($2)))"},
|
||||||
|
{"acosh", "(($1) ::acoshf(($2)))"},
|
||||||
|
{"tan", "(($1) ::tanf(($2)))"},
|
||||||
|
{"atan", "(($1) ::atanf(($2)))"},
|
||||||
|
{"tanh", "(($1) ::tanhf(($2)))"},
|
||||||
|
{"atanh", "(($1) ::atanhf(($2)))"},
|
||||||
|
{"sigmoid", "(($1) (1.0f/(1.0f+::expf((::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300))))))))"},
|
||||||
|
{"erf", "(($1) ::erff(($2)))"},
|
||||||
|
{"erfinv", "(($1) ::erfinvf(($1)($2)))"},
|
||||||
|
{"cast", "(($1)($2))"},
|
||||||
|
{"pow", "::pow(($2),($4))"},
|
||||||
|
{"maximum", "::max($1($2), $1($4))"},
|
||||||
|
{"minimum", "::min($1($2), $1($4))"},
|
||||||
|
{"mod", "@if(@strcmp($1,float32)==0,(($2)-::floorf(($2)/($4))*($4)),@if(@strcmp(@Tx,float64)==0,(($2)-::floor(($2)/($4))*($4)),(($2)%($4))))"},
|
||||||
|
{"init_maximum", "::numeric_min<$1>()"},
|
||||||
|
{"init_minimum", "::numeric_max<$1>()"},
|
||||||
|
};
|
||||||
|
|
||||||
|
static unordered_map<string,string> cpu_map = {
|
||||||
|
{"logical_not", "(!($2))"},
|
||||||
|
{"bitwise_not", "(~($2))"},
|
||||||
|
{"negative", "(-($2))"},
|
||||||
|
{"abs", "std::abs($2)"},
|
||||||
|
{"log", "std::log(($1)($2))"},
|
||||||
|
{"exp", "std::exp(($1)($2))"},
|
||||||
|
{"sqrt", "std::sqrt(($1)($2))"},
|
||||||
|
{"round", "(($1)std::round(($2)))"},
|
||||||
|
{"floor", "(($1)std::floor(($2)))"},
|
||||||
|
{"ceil", "(($1)std::ceil(($2)))"},
|
||||||
|
{"round_int", "(($1)std::round(($2)))"},
|
||||||
|
{"floor_int", "(($1)std::floor(($2)))"},
|
||||||
|
{"ceil_int", "(($1)std::ceil(($2)))"},
|
||||||
|
{"sin", "(($1) std::sin(($2)))"},
|
||||||
|
{"asin", "(($1) std::asin(($2)))"},
|
||||||
|
{"sinh", "(($1) std::sinh(($2)))"},
|
||||||
|
{"asinh", "(($1) std::asinh(($2)))"},
|
||||||
|
{"cos", "(($1) std::cos(($2)))"},
|
||||||
|
{"acos", "(($1) std::acos(($2)))"},
|
||||||
|
{"cosh", "(($1) std::cosh(($2)))"},
|
||||||
|
{"acosh", "(($1) std::acosh(($2)))"},
|
||||||
|
{"tan", "(($1) std::tan(($2)))"},
|
||||||
|
{"atan", "(($1) std::atan(($2)))"},
|
||||||
|
{"tanh", "(($1) std::tanh(($2)))"},
|
||||||
|
{"atanh", "(($1) std::atanh(($2)))"},
|
||||||
|
{"sigmoid", "(($1) (1.0f/(1.0f+std::exp(std::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300)))))))"},
|
||||||
|
{"erf", "(($1) std::erf(($2)))"},
|
||||||
|
{"erfinv", "(jittor::_erfinv($2))"},
|
||||||
|
{"cast", "(($1)($2))"},
|
||||||
|
{"pow", "std::pow(($2),($4))"},
|
||||||
|
{"maximum", "std::max($1($2), $1($4))"},
|
||||||
|
{"minimum", "std::min($1($2), $1($4))"},
|
||||||
|
{"mod", "@if(@strcmp($1,float32)==0,(($2)-std::floor(($2)/($4))*($4)),@if(@strcmp(@Tx,float64)==0,(($2)-std::floor(($2)/($4))*($4)),(($2)%($4))))"},
|
||||||
|
{"init_maximum", "std::numeric_limits<$1>::lowest()"},
|
||||||
|
{"init_minimum", "std::numeric_limits<$1>::max()"},
|
||||||
|
};
|
||||||
|
|
||||||
|
static unordered_map<string,string> both_map {
|
||||||
|
{"add", "(($2)+($4))"},
|
||||||
|
{"subtract", "(($2)-($4))"},
|
||||||
|
{"multiply", "(($2)*($4))"},
|
||||||
|
{"divide", "($1(($1($2))/($1($4))))"},
|
||||||
|
{"floor_divide", "($1(($1($2))/($1($4))))"},
|
||||||
|
{"less", "(($2)<($4))"},
|
||||||
|
{"less_equal", "(($2)<=($4))"},
|
||||||
|
{"greater", "(($2)>($4))"},
|
||||||
|
{"greater_equal", "(($2)>=($4))"},
|
||||||
|
{"equal", "(($2)==($4))"},
|
||||||
|
{"not_equal", "(($2)!=($4))"},
|
||||||
|
{"left_shift", "(($2)<<($4))"},
|
||||||
|
{"right_shift", "(($2)>>($4))"},
|
||||||
|
{"logical_and", "(($2)&&($4))"},
|
||||||
|
{"logical_or", "(($2)||($4))"},
|
||||||
|
{"logical_xor", "((bool($2))!=(bool($4)))"},
|
||||||
|
{"bitwise_and", "(($2)&($4))"},
|
||||||
|
{"bitwise_or", "(($2)|($4))"},
|
||||||
|
{"bitwise_xor", "(($2)^($4))"},
|
||||||
|
{"mean", "(($2)+$1($4)*($1(rcount)))"},
|
||||||
|
{"init_add", "$1(0)"},
|
||||||
|
{"init_multiply", "$1(1)"},
|
||||||
|
{"init_logical_and", "true"},
|
||||||
|
{"init_logical_or", "false"},
|
||||||
|
{"init_logical_xor", "false"},
|
||||||
|
{"init_bitwise_and", "$1(-1)"},
|
||||||
|
{"init_bitwise_or", "$1(0)"},
|
||||||
|
{"init_bitwise_xor", "$1(0)"},
|
||||||
|
{"init_mean", "$1(0)"},
|
||||||
|
};
|
||||||
|
|
||||||
|
string ret;
|
||||||
|
if (both_map.count(args.at(0)))
|
||||||
|
ret = both_map[args.at(0)];
|
||||||
|
else if (use_cuda)
|
||||||
|
ret = cuda_map[args.at(0)];
|
||||||
|
else
|
||||||
|
ret = cpu_map[args.at(0)];
|
||||||
|
return format(ret, args);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
static int _ = registe_op_type(new CommonOpType());
|
||||||
|
|
||||||
|
}
|
|
@ -33,4 +33,12 @@ OpInfo get_op_info(const string& name) {
|
||||||
return op_info_map.at(op_file_name);
|
return op_info_map.at(op_file_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
vector<OpByType*> op_types;
|
||||||
|
|
||||||
|
int registe_op_type(OpByType* op_type) {
|
||||||
|
op_types.push_back(op_type);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
} // jittor
|
} // jittor
|
|
@ -32,4 +32,12 @@ void op_registe(const OpInfo& op_info);
|
||||||
bool has_op(const string& name);
|
bool has_op(const string& name);
|
||||||
OpInfo get_op_info(const string& name);
|
OpInfo get_op_info(const string& name);
|
||||||
|
|
||||||
|
struct OpByType {
|
||||||
|
unordered_set<string> types;
|
||||||
|
virtual string expand_op(const vector<string>& args) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
extern vector<OpByType*> op_types;
|
||||||
|
int registe_op_type(OpByType*);
|
||||||
|
|
||||||
} // jittor
|
} // jittor
|
|
@ -8,7 +8,6 @@
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include "var.h"
|
#include "var.h"
|
||||||
#include "ops/reduce_op.h"
|
#include "ops/reduce_op.h"
|
||||||
#include "ops/binary_op_defs.h"
|
|
||||||
#include "ops/op_register.h"
|
#include "ops/op_register.h"
|
||||||
#include "executor.h"
|
#include "executor.h"
|
||||||
|
|
||||||
|
@ -364,14 +363,14 @@ void ReduceOp::jit_run() {
|
||||||
Ty rcount = Ty(y->num) / Ty(x->num);
|
Ty rcount = Ty(y->num) / Ty(x->num);
|
||||||
@for(d, 0, DIM,@if(REDUCE>>d&1,, for (index_t xi@d=0; xi@d < xshape@d; xi@d++))) {
|
@for(d, 0, DIM,@if(REDUCE>>d&1,, for (index_t xi@d=0; xi@d < xshape@d; xi@d++))) {
|
||||||
auto yid = 0 @for(d, 0, DIM,@if(REDUCE>>d&1,, + xi@d * ystride@d));
|
auto yid = 0 @for(d, 0, DIM,@if(REDUCE>>d&1,, + xi@d * ystride@d));
|
||||||
yp[yid] = @expand_macro(init_@OP, Ty);
|
yp[yid] = @expand_op(init_@OP, @Ty);
|
||||||
}
|
}
|
||||||
|
|
||||||
@for(d, 0, DIM,@if(REDUCE>>d&1,, for (index_t xi@d=0; xi@d < xshape@d; xi@d++))) {
|
@for(d, 0, DIM,@if(REDUCE>>d&1,, for (index_t xi@d=0; xi@d < xshape@d; xi@d++))) {
|
||||||
@for(d, 0, DIM,@if(REDUCE>>d&1, for (index_t xi@d=0; xi@d < xshape@d; xi@d++),)) {
|
@for(d, 0, DIM,@if(REDUCE>>d&1, for (index_t xi@d=0; xi@d < xshape@d; xi@d++),)) {
|
||||||
auto yid = 0 @for(d, 0, DIM,@if(REDUCE>>d&1,, + xi@d * ystride@d));
|
auto yid = 0 @for(d, 0, DIM,@if(REDUCE>>d&1,, + xi@d * ystride@d));
|
||||||
auto xid = 0 @for(d, 0, DIM, + xi@d * xstride@d);
|
auto xid = 0 @for(d, 0, DIM, + xi@d * xstride@d);
|
||||||
yp[yid] = @expand_macro(@OP, Ty, yp[yid], xp[xid]);
|
yp[yid] = @expand_op(@OP, @Ty, yp[yid], @Ty, xp[xid], @Tx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(void)count, (void)rcount, (void)yshape0, (void)ystride0;
|
(void)count, (void)rcount, (void)yshape0, (void)ystride0;
|
||||||
|
|
|
@ -8,7 +8,6 @@
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include "var.h"
|
#include "var.h"
|
||||||
#include "ops/reindex_reduce_op.h"
|
#include "ops/reindex_reduce_op.h"
|
||||||
#include "ops/binary_op_defs.h"
|
|
||||||
#include "ops/op_register.h"
|
#include "ops/op_register.h"
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
@ -112,7 +111,7 @@ void ReindexReduceOp::jit_run() {
|
||||||
|
|
||||||
@for(d, 0, XDIM, for (index_t i@d=0; i@d < xshape@d; i@d++)) {
|
@for(d, 0, XDIM, for (index_t i@d=0; i@d < xshape@d; i@d++)) {
|
||||||
auto xid = @for(d, 0, XDIM, + i@d * xstride@d);
|
auto xid = @for(d, 0, XDIM, + i@d * xstride@d);
|
||||||
xp[xid] = @expand_macro(init_@OP, Tx);
|
xp[xid] = @expand_op(init_@OP, @Tx);
|
||||||
}
|
}
|
||||||
// generate d-for loop
|
// generate d-for loop
|
||||||
@for(d, 0, YDIM, for (index_t i@d=0; i@d < yshape@d; i@d++)) {
|
@for(d, 0, YDIM, for (index_t i@d=0; i@d < yshape@d; i@d++)) {
|
||||||
|
@ -121,7 +120,7 @@ void ReindexReduceOp::jit_run() {
|
||||||
auto xid = @for(d, 0, XDIM, + xid@d * xstride@d);
|
auto xid = @for(d, 0, XDIM, + xid@d * xstride@d);
|
||||||
bool check_overflow = 0 @for(d, 0, XDIM, || xid@d<0 || xid@d>=xshape@d) @for(d, 0, OSIZE, || (@expand_macro(OFD@d)));
|
bool check_overflow = 0 @for(d, 0, XDIM, || xid@d<0 || xid@d>=xshape@d) @for(d, 0, OSIZE, || (@expand_macro(OFD@d)));
|
||||||
if (!check_overflow)
|
if (!check_overflow)
|
||||||
xp[xid] = @expand_macro(@OP, Tx, xp[xid], yp[yid]);
|
xp[xid] = @expand_op(@OP, @Tx, xp[xid], @Tx, yp[yid], @Tx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif // JIT
|
#endif // JIT
|
||||||
|
|
|
@ -9,7 +9,6 @@
|
||||||
#include "ops/setitem_op.h"
|
#include "ops/setitem_op.h"
|
||||||
#include "ops/getitem_op.h"
|
#include "ops/getitem_op.h"
|
||||||
#ifdef JIT
|
#ifdef JIT
|
||||||
#include "ops/binary_op_defs.h"
|
|
||||||
#ifdef JIT_cuda
|
#ifdef JIT_cuda
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include "helper_cuda.h"
|
#include "helper_cuda.h"
|
||||||
|
@ -340,12 +339,12 @@ void SetitemOp::jit_run() {
|
||||||
@if(@is_def(JIT_cpu),
|
@if(@is_def(JIT_cpu),
|
||||||
@if(@strcmp(@OP,void)==0,
|
@if(@strcmp(@OP,void)==0,
|
||||||
op[iid] = (Ti)dp[did],
|
op[iid] = (Ti)dp[did],
|
||||||
op[iid] = @expand_macro(@OP, Ti, op[iid], dp[did])
|
op[iid] = @expand_op(@OP, @Ti, op[iid], @Ti, dp[did], @Td)
|
||||||
);
|
);
|
||||||
,
|
,
|
||||||
@if(@strcmp(@OP,void)==0, op[iid] = (Ti)dp[did],
|
@if(@strcmp(@OP,void)==0, op[iid] = (Ti)dp[did],
|
||||||
@if(@strcmp(@OP,add)==0, atomicAdd(&op[iid], (Ti)dp[did]),
|
@if(@strcmp(@OP,add)==0, atomicAdd(&op[iid], (Ti)dp[did]),
|
||||||
op[iid] = @expand_macro(@OP, Ti, op[iid], dp[did])
|
op[iid] = @expand_op(@OP, @Ti, op[iid], @Ti, dp[did], @Td)
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
)
|
)
|
||||||
|
|
|
@ -8,7 +8,6 @@
|
||||||
#include "misc/cpu_math.h"
|
#include "misc/cpu_math.h"
|
||||||
#include "var.h"
|
#include "var.h"
|
||||||
#include "ops/unary_op.h"
|
#include "ops/unary_op.h"
|
||||||
#include "ops/unary_op_defs.h"
|
|
||||||
#include "ops/op_register.h"
|
#include "ops/op_register.h"
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
@ -688,7 +687,7 @@ void UnaryOp::jit_run() {
|
||||||
auto* __restrict__ yp = y->ptr<Ty>();
|
auto* __restrict__ yp = y->ptr<Ty>();
|
||||||
index_t num = y->num;
|
index_t num = y->num;
|
||||||
for (index_t i=0; i<num; i++)
|
for (index_t i=0; i<num; i++)
|
||||||
yp[i] = @expand_macro(@OP, Ty, xp[i]);
|
yp[i] = @expand_op(@OP, @Ty, xp[i], @Tx);
|
||||||
}
|
}
|
||||||
#endif // JIT
|
#endif // JIT
|
||||||
|
|
||||||
|
|
|
@ -1,84 +0,0 @@
|
||||||
// ***************************************************************
|
|
||||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
|
||||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
|
||||||
// This file is subject to the terms and conditions defined in
|
|
||||||
// file 'LICENSE.txt', which is part of this source code package.
|
|
||||||
// ***************************************************************
|
|
||||||
#pragma once
|
|
||||||
#include "common.h"
|
|
||||||
|
|
||||||
namespace jittor {
|
|
||||||
|
|
||||||
#define logical_not(T,x) (!(x))
|
|
||||||
#define bitwise_not(T,x) (~(x))
|
|
||||||
#define negative(T,x) (-(x))
|
|
||||||
#ifdef JIT_cuda
|
|
||||||
// TODO: add float64 version
|
|
||||||
#define abs(T,x) ::abs(x)
|
|
||||||
#define log(T,x) ::logf((T)(x))
|
|
||||||
#define exp(T,x) ::expf((T)(x))
|
|
||||||
#define sqrt(T,x) ::sqrtf((T)(x))
|
|
||||||
#define round(T,x) ((T) ::roundf((x)))
|
|
||||||
#define floor(T,x) ((T) ::floorf((x)))
|
|
||||||
#define ceil(T,x) ((T) ::ceilf((x)))
|
|
||||||
#define round_int(T,x) ((T) ::roundf((x)))
|
|
||||||
#define floor_int(T,x) ((T) ::floorf((x)))
|
|
||||||
#define ceil_int(T,x) ((T) ::ceilf((x)))
|
|
||||||
|
|
||||||
#define sin(T,x) ((T) ::sinf((x)))
|
|
||||||
#define asin(T,x) ((T) ::asinf((x)))
|
|
||||||
#define sinh(T,x) ((T) ::sinhf((x)))
|
|
||||||
#define asinh(T,x) ((T) ::asinhf((x)))
|
|
||||||
|
|
||||||
#define cos(T,x) ((T) ::cosf((x)))
|
|
||||||
#define acos(T,x) ((T) ::acosf((x)))
|
|
||||||
#define cosh(T,x) ((T) ::coshf((x)))
|
|
||||||
#define acosh(T,x) ((T) ::acoshf((x)))
|
|
||||||
|
|
||||||
#define tan(T,x) ((T) ::tanf((x)))
|
|
||||||
#define atan(T,x) ((T) ::atanf((x)))
|
|
||||||
#define tanh(T,x) ((T) ::tanhf((x)))
|
|
||||||
#define atanh(T,x) ((T) ::atanhf((x)))
|
|
||||||
|
|
||||||
#define sigmoid(T,x) ((T) (1.0f/(1.0f+::expf((::min(T(-(x)), T(@if(@strcmp(@T,float32)==0,30,300))))))))
|
|
||||||
|
|
||||||
#define erf(T,x) ((T) ::erff((x)))
|
|
||||||
#define erfinv(T,x) ((T) ::erfinvf((T)(x)))
|
|
||||||
|
|
||||||
#else
|
|
||||||
#define abs(T,x) std::abs(x)
|
|
||||||
#define log(T,x) std::log((T)(x))
|
|
||||||
#define exp(T,x) std::exp((T)(x))
|
|
||||||
#define sqrt(T,x) std::sqrt((T)(x))
|
|
||||||
#define round(T,x) ((T)std::round((x)))
|
|
||||||
#define floor(T,x) ((T)std::floor((x)))
|
|
||||||
#define ceil(T,x) ((T)std::ceil((x)))
|
|
||||||
#define round_int(T,x) ((T)std::round((x)))
|
|
||||||
#define floor_int(T,x) ((T)std::floor((x)))
|
|
||||||
#define ceil_int(T,x) ((T)std::ceil((x)))
|
|
||||||
|
|
||||||
#define sin(T,x) ((T) std::sin((x)))
|
|
||||||
#define asin(T,x) ((T) std::asin((x)))
|
|
||||||
#define sinh(T,x) ((T) std::sinh((x)))
|
|
||||||
#define asinh(T,x) ((T) std::asinh((x)))
|
|
||||||
|
|
||||||
#define cos(T,x) ((T) std::cos((x)))
|
|
||||||
#define acos(T,x) ((T) std::acos((x)))
|
|
||||||
#define cosh(T,x) ((T) std::cosh((x)))
|
|
||||||
#define acosh(T,x) ((T) std::acosh((x)))
|
|
||||||
|
|
||||||
#define tan(T,x) ((T) std::tan((x)))
|
|
||||||
#define atan(T,x) ((T) std::atan((x)))
|
|
||||||
#define tanh(T,x) ((T) std::tanh((x)))
|
|
||||||
#define atanh(T,x) ((T) std::atanh((x)))
|
|
||||||
|
|
||||||
#define sigmoid(T,x) ((T) (1.0f/(1.0f+std::exp(std::min(T(-(x)), T(@if(@strcmp(@T,float32)==0,30,300)))))))
|
|
||||||
|
|
||||||
#define erf(T,x) ((T) std::erf((x)))
|
|
||||||
#define erfinv(T,x) (jittor::_erfinv(x))
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#define cast(T,x) ((T)(x))
|
|
||||||
|
|
||||||
} // jittor
|
|
|
@ -47,4 +47,18 @@ string strip(const string& s) {
|
||||||
return s.substr(i,j-i);
|
return s.substr(i,j-i);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string format(const string& s, const vector<string>& v) {
|
||||||
|
string ss;
|
||||||
|
for (int i=0; i<s.size(); i++) {
|
||||||
|
if (s[i] == '$') {
|
||||||
|
int j = s[i+1] - '0';
|
||||||
|
ss += v.at(j);
|
||||||
|
i ++;
|
||||||
|
continue;
|
||||||
|
} else
|
||||||
|
ss += s[i];
|
||||||
|
}
|
||||||
|
return ss;
|
||||||
|
}
|
||||||
|
|
||||||
} // jittor
|
} // jittor
|
|
@ -27,4 +27,6 @@ vector<string> split(const string& s, const string& sep, int max_split=0);
|
||||||
|
|
||||||
string strip(const string& s);
|
string strip(const string& s);
|
||||||
|
|
||||||
|
string format(const string& s, const vector<string>& v);
|
||||||
|
|
||||||
} // jittor
|
} // jittor
|
|
@ -111,17 +111,6 @@ class TestOpCompiler(unittest.TestCase):
|
||||||
check("@{a^b == 7}", "2")
|
check("@{a^b == 7}", "2")
|
||||||
check("@{(a^b) == 7}", "1")
|
check("@{(a^b) == 7}", "1")
|
||||||
check("@{b<<a == 5*4}", "1")
|
check("@{b<<a == 5*4}", "1")
|
||||||
check('''#include "ops/binary_op_defs.h"
|
|
||||||
#define OP1(a, b) a+b
|
|
||||||
OP1
|
|
||||||
@expand_macro(OP1,1,2)
|
|
||||||
@expand_macro(maximum, T, 1, 2)
|
|
||||||
@expand_macro(@OP,T,1,2)''',
|
|
||||||
''' #define OP1(a, b) a+b
|
|
||||||
OP1
|
|
||||||
1+2
|
|
||||||
std::max(T(1), T(2))
|
|
||||||
((1)+T(2)*(T(rcount)))''')
|
|
||||||
expect_error(lambda: jit_precompile(vars, "@{a"))
|
expect_error(lambda: jit_precompile(vars, "@{a"))
|
||||||
expect_error(lambda: jit_precompile(vars, "@for(a"))
|
expect_error(lambda: jit_precompile(vars, "@for(a"))
|
||||||
expect_error(lambda: jit_precompile(vars, "@for(i,l,r)"))
|
expect_error(lambda: jit_precompile(vars, "@for(i,l,r)"))
|
||||||
|
|
Loading…
Reference in New Issue