mirror of https://github.com/Jittor/Jittor
fp16 support, opbytype interface
This commit is contained in:
parent
81b847e6f0
commit
bcb57086c3
|
@ -1235,6 +1235,7 @@ files4 = [ f[len(jittor_path)+1:] for f in files4 ]
|
|||
# files4 = run_cmd('find -L src | grep '+grep_args, jittor_path).splitlines()
|
||||
at_beginning = [
|
||||
"src/ops/op_utils.cc",
|
||||
"src/ops/op_register.cc",
|
||||
"src/event_queue.cc",
|
||||
"src/mem/allocator/sfrl_allocator.cc",
|
||||
"src/mem/allocator.cc",
|
||||
|
|
|
@ -60,14 +60,14 @@ class Pool(Module):
|
|||
'''
|
||||
if not self.return_indices:
|
||||
forward_body += f'''
|
||||
@out(i0, i1, i2, i3) = init_{self.op}(out_type);
|
||||
@out(i0, i1, i2, i3) = @expand_op(init_{self.op}, @out_type);
|
||||
for (int p = k2; p < k2_; ++p)
|
||||
for (int q = k3; q < k3_; ++q)
|
||||
@out(i0, i1, i2, i3) = {self.op}(out_type, @out(i0, i1, i2, i3), @in0(i0, i1, p, q));
|
||||
@out(i0, i1, i2, i3) = @expand_op({self.op}, @out_type, @out(i0, i1, i2, i3), @out_type, @in0(i0, i1, p, q), @in0_type);
|
||||
'''
|
||||
else:
|
||||
forward_body += f'''
|
||||
auto out_value = init_{self.op}(out_type);
|
||||
auto out_value = @expand_op(init_{self.op}, @out_type);
|
||||
int out_index = -1;
|
||||
for (int p = k2; p < k2_; ++p)
|
||||
for (int q = k3; q < k3_; ++q)
|
||||
|
@ -105,7 +105,6 @@ class Pool(Module):
|
|||
return_dtypes = x.dtype
|
||||
out = jt.code(return_shapes, return_dtypes, [x],
|
||||
cuda_header="""
|
||||
#include <ops/binary_op_defs.h>
|
||||
#include <misc/cuda_limits.h>
|
||||
""",
|
||||
cuda_src=f'''
|
||||
|
@ -153,7 +152,7 @@ class Pool(Module):
|
|||
dim3 s2_(tx, ty);
|
||||
kernel3<<<s1_, s2_>>>(@ARGS);
|
||||
'''],
|
||||
cpu_header='#include <ops/binary_op_defs.h>',
|
||||
cpu_header='',
|
||||
cpu_src=f'''
|
||||
using namespace std;
|
||||
for (int i0=0; i0<out_shape0; i0++)
|
||||
|
@ -242,15 +241,15 @@ class Pool3d(Module):
|
|||
'''
|
||||
if not self.return_indices:
|
||||
forward_body += f'''
|
||||
@out(i0, i1, i2, i3, i4) = init_{self.op}(out_type);
|
||||
@out(i0, i1, i2, i3, i4) = @expand_op(init_{self.op}, @out_type);
|
||||
for (int p = k2; p < k2_; ++p)
|
||||
for (int q = k3; q < k3_; ++q)
|
||||
for (int r = k4; r < k4_; ++r)
|
||||
@out(i0, i1, i2, i3, i4) = {self.op}(out_type, @out(i0, i1, i2, i3, i4), @in0(i0, i1, p, q, r));
|
||||
@out(i0, i1, i2, i3, i4) = @expand_op({self.op}, @out_type, @out(i0, i1, i2, i3, i4), @out_type, @in0(i0, i1, p, q, r), @in0_type);
|
||||
'''
|
||||
else:
|
||||
forward_body += f'''
|
||||
auto out_value = init_{self.op}(out_type);
|
||||
auto out_value = @expand_op(init_{self.op}, @out_type);
|
||||
int out_index = -1;
|
||||
for (int p = k2; p < k2_; ++p)
|
||||
for (int q = k3; q < k3_; ++q)
|
||||
|
@ -293,7 +292,6 @@ class Pool3d(Module):
|
|||
return_dtypes = x.dtype
|
||||
out = jt.code(return_shapes, return_dtypes, [x],
|
||||
cuda_header="""
|
||||
#include <ops/binary_op_defs.h>
|
||||
#include <misc/cuda_limits.h>
|
||||
""",
|
||||
cuda_src=f'''
|
||||
|
@ -349,7 +347,7 @@ class Pool3d(Module):
|
|||
dim3 s2(tx, ty, tz);
|
||||
kernel3<<<s1, s2>>>(@ARGS);
|
||||
'''],
|
||||
cpu_header='#include <ops/binary_op_defs.h>',
|
||||
cpu_header='',
|
||||
cpu_src=f'''
|
||||
using namespace std;
|
||||
for (int i0=0; i0<out_shape0; i0++)
|
||||
|
|
|
@ -223,6 +223,16 @@ void load_macros(const string& src, unordered_map<string,string>& macros) {
|
|||
}
|
||||
}
|
||||
|
||||
// Resolve an `@expand_op(...)` request by asking each registered OpByType,
// in registration order, to expand `args`. A handler signals "not mine" by
// returning an empty string; the first non-empty expansion is returned.
//
// args: the already-split macro arguments (op name followed by types/operands).
// If no handler produces an expansion the failure is reported through LOGf.
// NOTE(review): LOGf presumably aborts/throws; the trailing return only keeps
// the function well-formed if it does not.
string expand_op_search(const vector<string>& args) {
    for (OpByType* handler : op_types) {
        string expansion = handler->expand_op(args);
        if (expansion.empty())
            continue;
        return expansion;
    }
    LOGf << "No expand op pattern found for args:" << args;
    return "";
}
|
||||
|
||||
void expand_macro(const string& macro, const vector<string>& args, string& new_src) {
|
||||
LOGvvvv << "expand_macro" << macro << "args:" << args;
|
||||
if (macro.size() == 0 || macro[0] != '<') {
|
||||
|
@ -434,6 +444,7 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
|||
vector<string> args;
|
||||
size_t l = k+1;
|
||||
if (expr == "for" || expr == "if" || expr == "expand_macro" ||
|
||||
expr == "expand_op" ||
|
||||
expr == "is_def" || expr == "python" ||
|
||||
(k<src.size() && src[k]=='(')) {
|
||||
ASSERT(src[k] == '(');
|
||||
|
@ -555,6 +566,18 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
|||
i = l-1;
|
||||
continue;
|
||||
} else
|
||||
if (expr == "expand_op") {
|
||||
// syntax: @expand_op(args)
|
||||
for (auto& arg : args) {
|
||||
uint p=0;
|
||||
while (p<arg.size() && arg[p] == ' ') p++;
|
||||
arg = precompile(defs, arg.substr(p), macros);
|
||||
}
|
||||
string ns = expand_op_search(args);
|
||||
new_src += precompile(defs, ns, macros);
|
||||
i = l-1;
|
||||
continue;
|
||||
} else
|
||||
if (expr == "define") {
|
||||
// syntax: @define(macro, value)
|
||||
// ij k l
|
||||
|
@ -846,6 +869,9 @@ string OpCompiler::__get_fused_src(
|
|||
};
|
||||
auto not_change = [&](const string& s) -> bool {
|
||||
if (unchanged.count(s)) return true;
|
||||
for (auto op_type : op_types)
|
||||
if (op_type->types.count(s))
|
||||
return true;
|
||||
return (s.find("::") != string::npos) || (s.find("LOG") != string::npos);
|
||||
};
|
||||
// regex find XxxXxxOp::jit_run
|
||||
|
|
|
@ -8,7 +8,6 @@
|
|||
#include "var.h"
|
||||
#include "ops/binary_op.h"
|
||||
#include "ops/broadcast_to_op.h"
|
||||
#include "ops/binary_op_defs.h"
|
||||
#include "ops/op_register.h"
|
||||
|
||||
namespace jittor {
|
||||
|
@ -554,7 +553,7 @@ void BinaryOp::jit_run() {
|
|||
auto* __restrict__ zp = z->ptr<Tz>();
|
||||
index_t num = z->num;
|
||||
for (index_t i=0; i<num; i++)
|
||||
zp[i] = @expand_macro(@OP, Tz, xp[i], yp[i]);
|
||||
zp[i] = @expand_op(@OP, @Tz, xp[i], @Tx, yp[i], @Ty);
|
||||
}
|
||||
#endif // JIT
|
||||
|
||||
|
|
|
@ -1,61 +0,0 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "common.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifdef JIT_cuda
|
||||
#define pow(T,a,b) ::pow(a,b)
|
||||
#define maximum(T,a,b) ::max(T(a), T(b))
|
||||
#define minimum(T,a,b) ::min(T(a), T(b))
|
||||
#define mod(T,a,b) @if(@strcmp(@T,float32)==0,(a-::floorf((a)/(b))*(b)),@if(@strcmp(@Tx,float64)==0,(a-::floor((a)/(b))*(b)),(a%b)))
|
||||
#else // JIT_cpu
|
||||
#define pow(T,a,b) std::pow(a,b)
|
||||
#define maximum(T,a,b) std::max(T(a), T(b))
|
||||
#define minimum(T,a,b) std::min(T(a), T(b))
|
||||
#define mod(T,a,b) @if(@strcmp(@T,float32)==0,(a-std::floor((a)/(b))*(b)),@if(@strcmp(@Tx,float64)==0,(a-std::floor((a)/(b))*(b)),(a%b)))
|
||||
#endif
|
||||
#define add(T,a,b) ((a)+(b))
|
||||
#define subtract(T,a,b) ((a)-(b))
|
||||
#define multiply(T,a,b) ((a)*(b))
|
||||
#define divide(T,a,b) (T((T(a))/(T(b))))
|
||||
#define floor_divide(T,a,b) (T((T(a))/(T(b))))
|
||||
#define less(T,a,b) ((a)<(b))
|
||||
#define less_equal(T,a,b) ((a)<=(b))
|
||||
#define greater(T,a,b) ((a)>(b))
|
||||
#define greater_equal(T,a,b) ((a)>=(b))
|
||||
#define equal(T,a,b) ((a)==(b))
|
||||
#define not_equal(T,a,b) ((a)!=(b))
|
||||
#define left_shift(T,a,b) ((a)<<(b))
|
||||
#define right_shift(T,a,b) ((a)>>(b))
|
||||
#define logical_and(T,a,b) ((a)&&(b))
|
||||
#define logical_or(T,a,b) ((a)||(b))
|
||||
#define logical_xor(T,a,b) ((bool(a))!=(bool(b)))
|
||||
#define bitwise_and(T,a,b) ((a)&(b))
|
||||
#define bitwise_or(T,a,b) ((a)|(b))
|
||||
#define bitwise_xor(T,a,b) ((a)^(b))
|
||||
#define mean(T,a,b) ((a)+T(b)*(T(rcount)))
|
||||
|
||||
#ifdef JIT_cuda
|
||||
#define init_maximum(T) ::numeric_min<T>()
|
||||
#define init_minimum(T) ::numeric_max<T>()
|
||||
#else
|
||||
#define init_maximum(T) std::numeric_limits<T>::lowest()
|
||||
#define init_minimum(T) std::numeric_limits<T>::max()
|
||||
#endif
|
||||
#define init_add(T) T(0)
|
||||
#define init_multiply(T) T(1)
|
||||
#define init_logical_and(T) true
|
||||
#define init_logical_or(T) false
|
||||
#define init_logical_xor(T) false
|
||||
#define init_bitwise_and(T) T(-1)
|
||||
#define init_bitwise_or(T) T(0)
|
||||
#define init_bitwise_xor(T) T(0)
|
||||
#define init_mean(T) T(0)
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,159 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "common.h"
|
||||
#include "utils/str_utils.h"
|
||||
#include "ops/op_register.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
extern int use_cuda;
|
||||
|
||||
struct CommonOpType : OpByType {
|
||||
CommonOpType() {
|
||||
types = {
|
||||
"bool",
|
||||
"int8",
|
||||
"int16",
|
||||
"int32",
|
||||
"int64",
|
||||
"uint8",
|
||||
"uint16",
|
||||
"uint32",
|
||||
"uint64",
|
||||
"float32",
|
||||
"float64",
|
||||
};
|
||||
}
|
||||
|
||||
string expand_op(const vector<string>& args) {
|
||||
for (int i=1; i<args.size(); i+=2) {
|
||||
if (!types.count(args[i]))
|
||||
return "";
|
||||
}
|
||||
static unordered_map<string,string> cuda_map = {
|
||||
{"logical_not", "(!($2))"},
|
||||
{"bitwise_not", "(~($2))"},
|
||||
{"negative", "(-($2))"},
|
||||
{"abs", "::abs($2)"},
|
||||
{"log", "::logf(($1)($2))"},
|
||||
{"exp", "::expf(($1)($2))"},
|
||||
{"sqrt", "::sqrtf(($1)($2))"},
|
||||
{"round", "(($1) ::roundf(($2)))"},
|
||||
{"floor", "(($1) ::floorf(($2)))"},
|
||||
{"ceil", "(($1) ::ceilf(($2)))"},
|
||||
{"round_int", "(($1) ::roundf(($2)))"},
|
||||
{"floor_int", "(($1) ::floorf(($2)))"},
|
||||
{"ceil_int", "(($1) ::ceilf(($2)))"},
|
||||
{"sin", "(($1) ::sinf(($2)))"},
|
||||
{"asin", "(($1) ::asinf(($2)))"},
|
||||
{"sinh", "(($1) ::sinhf(($2)))"},
|
||||
{"asinh", "(($1) ::asinhf(($2)))"},
|
||||
{"cos", "(($1) ::cosf(($2)))"},
|
||||
{"acos", "(($1) ::acosf(($2)))"},
|
||||
{"cosh", "(($1) ::coshf(($2)))"},
|
||||
{"acosh", "(($1) ::acoshf(($2)))"},
|
||||
{"tan", "(($1) ::tanf(($2)))"},
|
||||
{"atan", "(($1) ::atanf(($2)))"},
|
||||
{"tanh", "(($1) ::tanhf(($2)))"},
|
||||
{"atanh", "(($1) ::atanhf(($2)))"},
|
||||
{"sigmoid", "(($1) (1.0f/(1.0f+::expf((::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300))))))))"},
|
||||
{"erf", "(($1) ::erff(($2)))"},
|
||||
{"erfinv", "(($1) ::erfinvf(($1)($2)))"},
|
||||
{"cast", "(($1)($2))"},
|
||||
{"pow", "::pow(($2),($4))"},
|
||||
{"maximum", "::max($1($2), $1($4))"},
|
||||
{"minimum", "::min($1($2), $1($4))"},
|
||||
{"mod", "@if(@strcmp($1,float32)==0,(($2)-::floorf(($2)/($4))*($4)),@if(@strcmp(@Tx,float64)==0,(($2)-::floor(($2)/($4))*($4)),(($2)%($4))))"},
|
||||
{"init_maximum", "::numeric_min<$1>()"},
|
||||
{"init_minimum", "::numeric_max<$1>()"},
|
||||
};
|
||||
|
||||
static unordered_map<string,string> cpu_map = {
|
||||
{"logical_not", "(!($2))"},
|
||||
{"bitwise_not", "(~($2))"},
|
||||
{"negative", "(-($2))"},
|
||||
{"abs", "std::abs($2)"},
|
||||
{"log", "std::log(($1)($2))"},
|
||||
{"exp", "std::exp(($1)($2))"},
|
||||
{"sqrt", "std::sqrt(($1)($2))"},
|
||||
{"round", "(($1)std::round(($2)))"},
|
||||
{"floor", "(($1)std::floor(($2)))"},
|
||||
{"ceil", "(($1)std::ceil(($2)))"},
|
||||
{"round_int", "(($1)std::round(($2)))"},
|
||||
{"floor_int", "(($1)std::floor(($2)))"},
|
||||
{"ceil_int", "(($1)std::ceil(($2)))"},
|
||||
{"sin", "(($1) std::sin(($2)))"},
|
||||
{"asin", "(($1) std::asin(($2)))"},
|
||||
{"sinh", "(($1) std::sinh(($2)))"},
|
||||
{"asinh", "(($1) std::asinh(($2)))"},
|
||||
{"cos", "(($1) std::cos(($2)))"},
|
||||
{"acos", "(($1) std::acos(($2)))"},
|
||||
{"cosh", "(($1) std::cosh(($2)))"},
|
||||
{"acosh", "(($1) std::acosh(($2)))"},
|
||||
{"tan", "(($1) std::tan(($2)))"},
|
||||
{"atan", "(($1) std::atan(($2)))"},
|
||||
{"tanh", "(($1) std::tanh(($2)))"},
|
||||
{"atanh", "(($1) std::atanh(($2)))"},
|
||||
{"sigmoid", "(($1) (1.0f/(1.0f+std::exp(std::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300)))))))"},
|
||||
{"erf", "(($1) std::erf(($2)))"},
|
||||
{"erfinv", "(jittor::_erfinv($2))"},
|
||||
{"cast", "(($1)($2))"},
|
||||
{"pow", "std::pow(($2),($4))"},
|
||||
{"maximum", "std::max($1($2), $1($4))"},
|
||||
{"minimum", "std::min($1($2), $1($4))"},
|
||||
{"mod", "@if(@strcmp($1,float32)==0,(($2)-std::floor(($2)/($4))*($4)),@if(@strcmp(@Tx,float64)==0,(($2)-std::floor(($2)/($4))*($4)),(($2)%($4))))"},
|
||||
{"init_maximum", "std::numeric_limits<$1>::lowest()"},
|
||||
{"init_minimum", "std::numeric_limits<$1>::max()"},
|
||||
};
|
||||
|
||||
static unordered_map<string,string> both_map {
|
||||
{"add", "(($2)+($4))"},
|
||||
{"subtract", "(($2)-($4))"},
|
||||
{"multiply", "(($2)*($4))"},
|
||||
{"divide", "($1(($1($2))/($1($4))))"},
|
||||
{"floor_divide", "($1(($1($2))/($1($4))))"},
|
||||
{"less", "(($2)<($4))"},
|
||||
{"less_equal", "(($2)<=($4))"},
|
||||
{"greater", "(($2)>($4))"},
|
||||
{"greater_equal", "(($2)>=($4))"},
|
||||
{"equal", "(($2)==($4))"},
|
||||
{"not_equal", "(($2)!=($4))"},
|
||||
{"left_shift", "(($2)<<($4))"},
|
||||
{"right_shift", "(($2)>>($4))"},
|
||||
{"logical_and", "(($2)&&($4))"},
|
||||
{"logical_or", "(($2)||($4))"},
|
||||
{"logical_xor", "((bool($2))!=(bool($4)))"},
|
||||
{"bitwise_and", "(($2)&($4))"},
|
||||
{"bitwise_or", "(($2)|($4))"},
|
||||
{"bitwise_xor", "(($2)^($4))"},
|
||||
{"mean", "(($2)+$1($4)*($1(rcount)))"},
|
||||
{"init_add", "$1(0)"},
|
||||
{"init_multiply", "$1(1)"},
|
||||
{"init_logical_and", "true"},
|
||||
{"init_logical_or", "false"},
|
||||
{"init_logical_xor", "false"},
|
||||
{"init_bitwise_and", "$1(-1)"},
|
||||
{"init_bitwise_or", "$1(0)"},
|
||||
{"init_bitwise_xor", "$1(0)"},
|
||||
{"init_mean", "$1(0)"},
|
||||
};
|
||||
|
||||
string ret;
|
||||
if (both_map.count(args.at(0)))
|
||||
ret = both_map[args.at(0)];
|
||||
else if (use_cuda)
|
||||
ret = cuda_map[args.at(0)];
|
||||
else
|
||||
ret = cpu_map[args.at(0)];
|
||||
return format(ret, args);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
static int _ = registe_op_type(new CommonOpType());
|
||||
|
||||
}
|
|
@ -33,4 +33,12 @@ OpInfo get_op_info(const string& name) {
|
|||
return op_info_map.at(op_file_name);
|
||||
}
|
||||
|
||||
vector<OpByType*> op_types;
|
||||
|
||||
// Add `op_type` to the end of the global `op_types` registry.
//
// The pointer is stored as given; nothing in this function takes a copy or
// frees it. The return value is always 0, which lets the call be used as the
// initializer of a dummy static variable to register a type at load time.
int registe_op_type(OpByType* op_type) {
    op_types.emplace_back(op_type);
    return 0;
}
|
||||
|
||||
|
||||
} // jittor
|
|
@ -32,4 +32,12 @@ void op_registe(const OpInfo& op_info);
|
|||
bool has_op(const string& name);
|
||||
OpInfo get_op_info(const string& name);
|
||||
|
||||
// Interface for a family of element types that knows how to expand
// `@expand_op(...)` requests into source text.
struct OpByType {
    // Names of the element types this handler is responsible for.
    unordered_set<string> types;

    // Expand the op described by `args` into a source snippet.
    // Implementations return an empty string when they do not handle `args`.
    virtual string expand_op(const vector<string>& args) = 0;

    // Instances are heap-allocated and held through OpByType*; a virtual
    // destructor keeps deletion through the base pointer well-defined.
    virtual ~OpByType() = default;
};
|
||||
|
||||
extern vector<OpByType*> op_types;
|
||||
int registe_op_type(OpByType*);
|
||||
|
||||
} // jittor
|
|
@ -8,7 +8,6 @@
|
|||
#include <limits>
|
||||
#include "var.h"
|
||||
#include "ops/reduce_op.h"
|
||||
#include "ops/binary_op_defs.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "executor.h"
|
||||
|
||||
|
@ -364,14 +363,14 @@ void ReduceOp::jit_run() {
|
|||
Ty rcount = Ty(y->num) / Ty(x->num);
|
||||
@for(d, 0, DIM,@if(REDUCE>>d&1,, for (index_t xi@d=0; xi@d < xshape@d; xi@d++))) {
|
||||
auto yid = 0 @for(d, 0, DIM,@if(REDUCE>>d&1,, + xi@d * ystride@d));
|
||||
yp[yid] = @expand_macro(init_@OP, Ty);
|
||||
yp[yid] = @expand_op(init_@OP, @Ty);
|
||||
}
|
||||
|
||||
@for(d, 0, DIM,@if(REDUCE>>d&1,, for (index_t xi@d=0; xi@d < xshape@d; xi@d++))) {
|
||||
@for(d, 0, DIM,@if(REDUCE>>d&1, for (index_t xi@d=0; xi@d < xshape@d; xi@d++),)) {
|
||||
auto yid = 0 @for(d, 0, DIM,@if(REDUCE>>d&1,, + xi@d * ystride@d));
|
||||
auto xid = 0 @for(d, 0, DIM, + xi@d * xstride@d);
|
||||
yp[yid] = @expand_macro(@OP, Ty, yp[yid], xp[xid]);
|
||||
yp[yid] = @expand_op(@OP, @Ty, yp[yid], @Ty, xp[xid], @Tx);
|
||||
}
|
||||
}
|
||||
(void)count, (void)rcount, (void)yshape0, (void)ystride0;
|
||||
|
|
|
@ -8,7 +8,6 @@
|
|||
#include <limits>
|
||||
#include "var.h"
|
||||
#include "ops/reindex_reduce_op.h"
|
||||
#include "ops/binary_op_defs.h"
|
||||
#include "ops/op_register.h"
|
||||
|
||||
namespace jittor {
|
||||
|
@ -112,7 +111,7 @@ void ReindexReduceOp::jit_run() {
|
|||
|
||||
@for(d, 0, XDIM, for (index_t i@d=0; i@d < xshape@d; i@d++)) {
|
||||
auto xid = @for(d, 0, XDIM, + i@d * xstride@d);
|
||||
xp[xid] = @expand_macro(init_@OP, Tx);
|
||||
xp[xid] = @expand_op(init_@OP, @Tx);
|
||||
}
|
||||
// generate d-for loop
|
||||
@for(d, 0, YDIM, for (index_t i@d=0; i@d < yshape@d; i@d++)) {
|
||||
|
@ -121,7 +120,7 @@ void ReindexReduceOp::jit_run() {
|
|||
auto xid = @for(d, 0, XDIM, + xid@d * xstride@d);
|
||||
bool check_overflow = 0 @for(d, 0, XDIM, || xid@d<0 || xid@d>=xshape@d) @for(d, 0, OSIZE, || (@expand_macro(OFD@d)));
|
||||
if (!check_overflow)
|
||||
xp[xid] = @expand_macro(@OP, Tx, xp[xid], yp[yid]);
|
||||
xp[xid] = @expand_op(@OP, @Tx, xp[xid], @Tx, yp[yid], @Tx);
|
||||
}
|
||||
}
|
||||
#endif // JIT
|
||||
|
|
|
@ -9,7 +9,6 @@
|
|||
#include "ops/setitem_op.h"
|
||||
#include "ops/getitem_op.h"
|
||||
#ifdef JIT
|
||||
#include "ops/binary_op_defs.h"
|
||||
#ifdef JIT_cuda
|
||||
#include <cuda_runtime.h>
|
||||
#include "helper_cuda.h"
|
||||
|
@ -340,12 +339,12 @@ void SetitemOp::jit_run() {
|
|||
@if(@is_def(JIT_cpu),
|
||||
@if(@strcmp(@OP,void)==0,
|
||||
op[iid] = (Ti)dp[did],
|
||||
op[iid] = @expand_macro(@OP, Ti, op[iid], dp[did])
|
||||
op[iid] = @expand_op(@OP, @Ti, op[iid], @Ti, dp[did], @Td)
|
||||
);
|
||||
,
|
||||
@if(@strcmp(@OP,void)==0, op[iid] = (Ti)dp[did],
|
||||
@if(@strcmp(@OP,add)==0, atomicAdd(&op[iid], (Ti)dp[did]),
|
||||
op[iid] = @expand_macro(@OP, Ti, op[iid], dp[did])
|
||||
op[iid] = @expand_op(@OP, @Ti, op[iid], @Ti, dp[did], @Td)
|
||||
)
|
||||
);
|
||||
)
|
||||
|
|
|
@ -8,7 +8,6 @@
|
|||
#include "misc/cpu_math.h"
|
||||
#include "var.h"
|
||||
#include "ops/unary_op.h"
|
||||
#include "ops/unary_op_defs.h"
|
||||
#include "ops/op_register.h"
|
||||
|
||||
namespace jittor {
|
||||
|
@ -688,7 +687,7 @@ void UnaryOp::jit_run() {
|
|||
auto* __restrict__ yp = y->ptr<Ty>();
|
||||
index_t num = y->num;
|
||||
for (index_t i=0; i<num; i++)
|
||||
yp[i] = @expand_macro(@OP, Ty, xp[i]);
|
||||
yp[i] = @expand_op(@OP, @Ty, xp[i], @Tx);
|
||||
}
|
||||
#endif // JIT
|
||||
|
||||
|
|
|
@ -1,84 +0,0 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "common.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#define logical_not(T,x) (!(x))
|
||||
#define bitwise_not(T,x) (~(x))
|
||||
#define negative(T,x) (-(x))
|
||||
#ifdef JIT_cuda
|
||||
// TODO: add float64 version
|
||||
#define abs(T,x) ::abs(x)
|
||||
#define log(T,x) ::logf((T)(x))
|
||||
#define exp(T,x) ::expf((T)(x))
|
||||
#define sqrt(T,x) ::sqrtf((T)(x))
|
||||
#define round(T,x) ((T) ::roundf((x)))
|
||||
#define floor(T,x) ((T) ::floorf((x)))
|
||||
#define ceil(T,x) ((T) ::ceilf((x)))
|
||||
#define round_int(T,x) ((T) ::roundf((x)))
|
||||
#define floor_int(T,x) ((T) ::floorf((x)))
|
||||
#define ceil_int(T,x) ((T) ::ceilf((x)))
|
||||
|
||||
#define sin(T,x) ((T) ::sinf((x)))
|
||||
#define asin(T,x) ((T) ::asinf((x)))
|
||||
#define sinh(T,x) ((T) ::sinhf((x)))
|
||||
#define asinh(T,x) ((T) ::asinhf((x)))
|
||||
|
||||
#define cos(T,x) ((T) ::cosf((x)))
|
||||
#define acos(T,x) ((T) ::acosf((x)))
|
||||
#define cosh(T,x) ((T) ::coshf((x)))
|
||||
#define acosh(T,x) ((T) ::acoshf((x)))
|
||||
|
||||
#define tan(T,x) ((T) ::tanf((x)))
|
||||
#define atan(T,x) ((T) ::atanf((x)))
|
||||
#define tanh(T,x) ((T) ::tanhf((x)))
|
||||
#define atanh(T,x) ((T) ::atanhf((x)))
|
||||
|
||||
#define sigmoid(T,x) ((T) (1.0f/(1.0f+::expf((::min(T(-(x)), T(@if(@strcmp(@T,float32)==0,30,300))))))))
|
||||
|
||||
#define erf(T,x) ((T) ::erff((x)))
|
||||
#define erfinv(T,x) ((T) ::erfinvf((T)(x)))
|
||||
|
||||
#else
|
||||
#define abs(T,x) std::abs(x)
|
||||
#define log(T,x) std::log((T)(x))
|
||||
#define exp(T,x) std::exp((T)(x))
|
||||
#define sqrt(T,x) std::sqrt((T)(x))
|
||||
#define round(T,x) ((T)std::round((x)))
|
||||
#define floor(T,x) ((T)std::floor((x)))
|
||||
#define ceil(T,x) ((T)std::ceil((x)))
|
||||
#define round_int(T,x) ((T)std::round((x)))
|
||||
#define floor_int(T,x) ((T)std::floor((x)))
|
||||
#define ceil_int(T,x) ((T)std::ceil((x)))
|
||||
|
||||
#define sin(T,x) ((T) std::sin((x)))
|
||||
#define asin(T,x) ((T) std::asin((x)))
|
||||
#define sinh(T,x) ((T) std::sinh((x)))
|
||||
#define asinh(T,x) ((T) std::asinh((x)))
|
||||
|
||||
#define cos(T,x) ((T) std::cos((x)))
|
||||
#define acos(T,x) ((T) std::acos((x)))
|
||||
#define cosh(T,x) ((T) std::cosh((x)))
|
||||
#define acosh(T,x) ((T) std::acosh((x)))
|
||||
|
||||
#define tan(T,x) ((T) std::tan((x)))
|
||||
#define atan(T,x) ((T) std::atan((x)))
|
||||
#define tanh(T,x) ((T) std::tanh((x)))
|
||||
#define atanh(T,x) ((T) std::atanh((x)))
|
||||
|
||||
#define sigmoid(T,x) ((T) (1.0f/(1.0f+std::exp(std::min(T(-(x)), T(@if(@strcmp(@T,float32)==0,30,300)))))))
|
||||
|
||||
#define erf(T,x) ((T) std::erf((x)))
|
||||
#define erfinv(T,x) (jittor::_erfinv(x))
|
||||
|
||||
#endif
|
||||
|
||||
#define cast(T,x) ((T)(x))
|
||||
|
||||
} // jittor
|
|
@ -47,4 +47,18 @@ string strip(const string& s) {
|
|||
return s.substr(i,j-i);
|
||||
}
|
||||
|
||||
// Fill positional placeholders in a template string.
//
// Each `$d` in `s`, where d is a single decimal digit, is replaced by
// `v.at(d)`; all other characters are copied through unchanged. Only one digit
// is consumed, so "$12" means v[1] followed by the character '2'.
//
// A '$' that is not followed by a digit (including a trailing '$') is copied
// literally. Previously such input produced a garbage index, which either
// threw or silently substituted an unrelated argument.
//
// Throws std::out_of_range if a placeholder index is >= v.size().
string format(const string& s, const vector<string>& v) {
    string out;
    out.reserve(s.size());
    for (size_t i = 0; i < s.size(); ++i) {
        const char c = s[i];
        const bool is_placeholder =
            c == '$' && i + 1 < s.size() && s[i+1] >= '0' && s[i+1] <= '9';
        if (!is_placeholder) {
            out += c;
            continue;
        }
        out += v.at(static_cast<size_t>(s[i+1] - '0'));
        ++i; // skip the digit that was just consumed
    }
    return out;
}
|
||||
|
||||
} // jittor
|
|
@ -27,4 +27,6 @@ vector<string> split(const string& s, const string& sep, int max_split=0);
|
|||
|
||||
string strip(const string& s);
|
||||
|
||||
string format(const string& s, const vector<string>& v);
|
||||
|
||||
} // jittor
|
|
@ -111,17 +111,6 @@ class TestOpCompiler(unittest.TestCase):
|
|||
check("@{a^b == 7}", "2")
|
||||
check("@{(a^b) == 7}", "1")
|
||||
check("@{b<<a == 5*4}", "1")
|
||||
check('''#include "ops/binary_op_defs.h"
|
||||
#define OP1(a, b) a+b
|
||||
OP1
|
||||
@expand_macro(OP1,1,2)
|
||||
@expand_macro(maximum, T, 1, 2)
|
||||
@expand_macro(@OP,T,1,2)''',
|
||||
''' #define OP1(a, b) a+b
|
||||
OP1
|
||||
1+2
|
||||
std::max(T(1), T(2))
|
||||
((1)+T(2)*(T(rcount)))''')
|
||||
expect_error(lambda: jit_precompile(vars, "@{a"))
|
||||
expect_error(lambda: jit_precompile(vars, "@for(a"))
|
||||
expect_error(lambda: jit_precompile(vars, "@for(i,l,r)"))
|
||||
|
|
Loading…
Reference in New Issue