fp16 support, opbytype interface

This commit is contained in:
Dun Liang 2022-02-20 20:34:49 +08:00
parent 81b847e6f0
commit bcb57086c3
16 changed files with 234 additions and 179 deletions

View File

@ -1235,6 +1235,7 @@ files4 = [ f[len(jittor_path)+1:] for f in files4 ]
# files4 = run_cmd('find -L src | grep '+grep_args, jittor_path).splitlines()
at_beginning = [
"src/ops/op_utils.cc",
"src/ops/op_register.cc",
"src/event_queue.cc",
"src/mem/allocator/sfrl_allocator.cc",
"src/mem/allocator.cc",

View File

@ -60,14 +60,14 @@ class Pool(Module):
'''
if not self.return_indices:
forward_body += f'''
@out(i0, i1, i2, i3) = init_{self.op}(out_type);
@out(i0, i1, i2, i3) = @expand_op(init_{self.op}, @out_type);
for (int p = k2; p < k2_; ++p)
for (int q = k3; q < k3_; ++q)
@out(i0, i1, i2, i3) = {self.op}(out_type, @out(i0, i1, i2, i3), @in0(i0, i1, p, q));
@out(i0, i1, i2, i3) = @expand_op({self.op}, @out_type, @out(i0, i1, i2, i3), @out_type, @in0(i0, i1, p, q), @in0_type);
'''
else:
forward_body += f'''
auto out_value = init_{self.op}(out_type);
auto out_value = @expand_op(init_{self.op}, @out_type);
int out_index = -1;
for (int p = k2; p < k2_; ++p)
for (int q = k3; q < k3_; ++q)
@ -105,7 +105,6 @@ class Pool(Module):
return_dtypes = x.dtype
out = jt.code(return_shapes, return_dtypes, [x],
cuda_header="""
#include <ops/binary_op_defs.h>
#include <misc/cuda_limits.h>
""",
cuda_src=f'''
@ -153,7 +152,7 @@ class Pool(Module):
dim3 s2_(tx, ty);
kernel3<<<s1_, s2_>>>(@ARGS);
'''],
cpu_header='#include <ops/binary_op_defs.h>',
cpu_header='',
cpu_src=f'''
using namespace std;
for (int i0=0; i0<out_shape0; i0++)
@ -242,15 +241,15 @@ class Pool3d(Module):
'''
if not self.return_indices:
forward_body += f'''
@out(i0, i1, i2, i3, i4) = init_{self.op}(out_type);
@out(i0, i1, i2, i3, i4) = @expand_op(init_{self.op}, @out_type);
for (int p = k2; p < k2_; ++p)
for (int q = k3; q < k3_; ++q)
for (int r = k4; r < k4_; ++r)
@out(i0, i1, i2, i3, i4) = {self.op}(out_type, @out(i0, i1, i2, i3, i4), @in0(i0, i1, p, q, r));
@out(i0, i1, i2, i3, i4) = @expand_op({self.op}, @out_type, @out(i0, i1, i2, i3, i4), @out_type, @in0(i0, i1, p, q, r), @in0_type);
'''
else:
forward_body += f'''
auto out_value = init_{self.op}(out_type);
auto out_value = @expand_op(init_{self.op}, @out_type);
int out_index = -1;
for (int p = k2; p < k2_; ++p)
for (int q = k3; q < k3_; ++q)
@ -293,7 +292,6 @@ class Pool3d(Module):
return_dtypes = x.dtype
out = jt.code(return_shapes, return_dtypes, [x],
cuda_header="""
#include <ops/binary_op_defs.h>
#include <misc/cuda_limits.h>
""",
cuda_src=f'''
@ -349,7 +347,7 @@ class Pool3d(Module):
dim3 s2(tx, ty, tz);
kernel3<<<s1, s2>>>(@ARGS);
'''],
cpu_header='#include <ops/binary_op_defs.h>',
cpu_header='',
cpu_src=f'''
using namespace std;
for (int i0=0; i0<out_shape0; i0++)

View File

@ -223,6 +223,16 @@ void load_macros(const string& src, unordered_map<string,string>& macros) {
}
}
string expand_op_search(const vector<string>& args) {
for (auto op_type : op_types) {
string ret = op_type->expand_op(args);
if (ret.size())
return ret;
}
LOGf << "No expand op pattern found for args:" << args;
return "";
}
void expand_macro(const string& macro, const vector<string>& args, string& new_src) {
LOGvvvv << "expand_macro" << macro << "args:" << args;
if (macro.size() == 0 || macro[0] != '<') {
@ -434,6 +444,7 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
vector<string> args;
size_t l = k+1;
if (expr == "for" || expr == "if" || expr == "expand_macro" ||
expr == "expand_op" ||
expr == "is_def" || expr == "python" ||
(k<src.size() && src[k]=='(')) {
ASSERT(src[k] == '(');
@ -555,6 +566,18 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
i = l-1;
continue;
} else
if (expr == "expand_op") {
// syntax: @expand_op(args)
for (auto& arg : args) {
uint p=0;
while (p<arg.size() && arg[p] == ' ') p++;
arg = precompile(defs, arg.substr(p), macros);
}
string ns = expand_op_search(args);
new_src += precompile(defs, ns, macros);
i = l-1;
continue;
} else
if (expr == "define") {
// syntax: @define(macro, value)
// ij k l
@ -846,6 +869,9 @@ string OpCompiler::__get_fused_src(
};
// Predicate over a token `s`. Returns true when any of the following holds:
//   - `s` is listed in `unchanged`;
//   - `s` is a type name known to one of the registered OpByType handlers;
//   - `s` contains "::" or "LOG".
// NOTE(review): presumably a true result means the identifier must be left
// as-is while building the fused source — confirm against the caller.
auto not_change = [&](const string& s) -> bool {
if (unchanged.count(s)) return true;
// Type names owned by op-by-type handlers are matched here as well.
for (auto op_type : op_types)
if (op_type->types.count(s))
return true;
return (s.find("::") != string::npos) || (s.find("LOG") != string::npos);
};
// regex find XxxXxxOp::jit_run

View File

@ -8,7 +8,6 @@
#include "var.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/binary_op_defs.h"
#include "ops/op_register.h"
namespace jittor {
@ -554,7 +553,7 @@ void BinaryOp::jit_run() {
auto* __restrict__ zp = z->ptr<Tz>();
index_t num = z->num;
for (index_t i=0; i<num; i++)
zp[i] = @expand_macro(@OP, Tz, xp[i], yp[i]);
zp[i] = @expand_op(@OP, @Tz, xp[i], @Tx, yp[i], @Ty);
}
#endif // JIT

View File

@ -1,61 +0,0 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "common.h"
namespace jittor {
#ifdef JIT_cuda
#define pow(T,a,b) ::pow(a,b)
#define maximum(T,a,b) ::max(T(a), T(b))
#define minimum(T,a,b) ::min(T(a), T(b))
#define mod(T,a,b) @if(@strcmp(@T,float32)==0,(a-::floorf((a)/(b))*(b)),@if(@strcmp(@Tx,float64)==0,(a-::floor((a)/(b))*(b)),(a%b)))
#else // JIT_cpu
#define pow(T,a,b) std::pow(a,b)
#define maximum(T,a,b) std::max(T(a), T(b))
#define minimum(T,a,b) std::min(T(a), T(b))
#define mod(T,a,b) @if(@strcmp(@T,float32)==0,(a-std::floor((a)/(b))*(b)),@if(@strcmp(@Tx,float64)==0,(a-std::floor((a)/(b))*(b)),(a%b)))
#endif
#define add(T,a,b) ((a)+(b))
#define subtract(T,a,b) ((a)-(b))
#define multiply(T,a,b) ((a)*(b))
#define divide(T,a,b) (T((T(a))/(T(b))))
#define floor_divide(T,a,b) (T((T(a))/(T(b))))
#define less(T,a,b) ((a)<(b))
#define less_equal(T,a,b) ((a)<=(b))
#define greater(T,a,b) ((a)>(b))
#define greater_equal(T,a,b) ((a)>=(b))
#define equal(T,a,b) ((a)==(b))
#define not_equal(T,a,b) ((a)!=(b))
#define left_shift(T,a,b) ((a)<<(b))
#define right_shift(T,a,b) ((a)>>(b))
#define logical_and(T,a,b) ((a)&&(b))
#define logical_or(T,a,b) ((a)||(b))
#define logical_xor(T,a,b) ((bool(a))!=(bool(b)))
#define bitwise_and(T,a,b) ((a)&(b))
#define bitwise_or(T,a,b) ((a)|(b))
#define bitwise_xor(T,a,b) ((a)^(b))
#define mean(T,a,b) ((a)+T(b)*(T(rcount)))
#ifdef JIT_cuda
#define init_maximum(T) ::numeric_min<T>()
#define init_minimum(T) ::numeric_max<T>()
#else
#define init_maximum(T) std::numeric_limits<T>::lowest()
#define init_minimum(T) std::numeric_limits<T>::max()
#endif
#define init_add(T) T(0)
#define init_multiply(T) T(1)
#define init_logical_and(T) true
#define init_logical_or(T) false
#define init_logical_xor(T) false
#define init_bitwise_and(T) T(-1)
#define init_bitwise_or(T) T(0)
#define init_bitwise_xor(T) T(0)
#define init_mean(T) T(0)
} // jittor

View File

@ -0,0 +1,159 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "common.h"
#include "utils/str_utils.h"
#include "ops/op_register.h"
namespace jittor {
extern int use_cuda;
// OpByType handler for the built-in scalar dtypes listed in `types`.
//
// expand_op() turns an op name plus its (type, value) arguments into a C++
// expression string by filling a per-op template:
//   args[0]          op name (e.g. "add", "init_maximum")
//   args[1], args[3] ... type names; each must be a member of `types`
//   args[2], args[4] ... value expressions
// Inside a template, "$k" is replaced by args[k] (see format()).
struct CommonOpType : OpByType {
    CommonOpType() {
        types = {
            "bool",
            "int8",
            "int16",
            "int32",
            "int64",
            "uint8",
            "uint16",
            "uint32",
            "uint64",
            "float32",
            "float64",
        };
    }

    // Returns the expanded expression, or an empty string when this handler
    // does not apply: either some odd-indexed argument is not one of `types`,
    // or the op name has no template for the active backend.
    // Throws std::out_of_range (from args.at(0)) when args is empty.
    string expand_op(const vector<string>& args) {
        // Every odd-indexed argument is a type name; reject unknown types so
        // that another registered handler can take over.
        for (size_t i=1; i<args.size(); i+=2) {
            if (!types.count(args[i]))
                return "";
        }

        // Templates used when use_cuda is set.
        static const unordered_map<string,string> cuda_map = {
            {"logical_not", "(!($2))"},
            {"bitwise_not", "(~($2))"},
            {"negative", "(-($2))"},
            {"abs", "::abs($2)"},
            {"log", "::logf(($1)($2))"},
            {"exp", "::expf(($1)($2))"},
            {"sqrt", "::sqrtf(($1)($2))"},
            {"round", "(($1) ::roundf(($2)))"},
            {"floor", "(($1) ::floorf(($2)))"},
            {"ceil", "(($1) ::ceilf(($2)))"},
            {"round_int", "(($1) ::roundf(($2)))"},
            {"floor_int", "(($1) ::floorf(($2)))"},
            {"ceil_int", "(($1) ::ceilf(($2)))"},
            {"sin", "(($1) ::sinf(($2)))"},
            {"asin", "(($1) ::asinf(($2)))"},
            {"sinh", "(($1) ::sinhf(($2)))"},
            {"asinh", "(($1) ::asinhf(($2)))"},
            {"cos", "(($1) ::cosf(($2)))"},
            {"acos", "(($1) ::acosf(($2)))"},
            {"cosh", "(($1) ::coshf(($2)))"},
            {"acosh", "(($1) ::acoshf(($2)))"},
            {"tan", "(($1) ::tanf(($2)))"},
            {"atan", "(($1) ::atanf(($2)))"},
            {"tanh", "(($1) ::tanhf(($2)))"},
            {"atanh", "(($1) ::atanhf(($2)))"},
            {"sigmoid", "(($1) (1.0f/(1.0f+::expf((::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300))))))))"},
            {"erf", "(($1) ::erff(($2)))"},
            {"erfinv", "(($1) ::erfinvf(($1)($2)))"},
            {"cast", "(($1)($2))"},
            {"pow", "::pow(($2),($4))"},
            {"maximum", "::max($1($2), $1($4))"},
            {"minimum", "::min($1($2), $1($4))"},
            // NOTE(review): the float64 test reads @Tx while the float32 test
            // reads $1; this looks like it should be $1 in both — confirm.
            {"mod", "@if(@strcmp($1,float32)==0,(($2)-::floorf(($2)/($4))*($4)),@if(@strcmp(@Tx,float64)==0,(($2)-::floor(($2)/($4))*($4)),(($2)%($4))))"},
            {"init_maximum", "::numeric_min<$1>()"},
            {"init_minimum", "::numeric_max<$1>()"},
        };

        // Templates used when use_cuda is not set.
        static const unordered_map<string,string> cpu_map = {
            {"logical_not", "(!($2))"},
            {"bitwise_not", "(~($2))"},
            {"negative", "(-($2))"},
            {"abs", "std::abs($2)"},
            {"log", "std::log(($1)($2))"},
            {"exp", "std::exp(($1)($2))"},
            {"sqrt", "std::sqrt(($1)($2))"},
            {"round", "(($1)std::round(($2)))"},
            {"floor", "(($1)std::floor(($2)))"},
            {"ceil", "(($1)std::ceil(($2)))"},
            {"round_int", "(($1)std::round(($2)))"},
            {"floor_int", "(($1)std::floor(($2)))"},
            {"ceil_int", "(($1)std::ceil(($2)))"},
            {"sin", "(($1) std::sin(($2)))"},
            {"asin", "(($1) std::asin(($2)))"},
            {"sinh", "(($1) std::sinh(($2)))"},
            {"asinh", "(($1) std::asinh(($2)))"},
            {"cos", "(($1) std::cos(($2)))"},
            {"acos", "(($1) std::acos(($2)))"},
            {"cosh", "(($1) std::cosh(($2)))"},
            {"acosh", "(($1) std::acosh(($2)))"},
            {"tan", "(($1) std::tan(($2)))"},
            {"atan", "(($1) std::atan(($2)))"},
            {"tanh", "(($1) std::tanh(($2)))"},
            {"atanh", "(($1) std::atanh(($2)))"},
            {"sigmoid", "(($1) (1.0f/(1.0f+std::exp(std::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300)))))))"},
            {"erf", "(($1) std::erf(($2)))"},
            {"erfinv", "(jittor::_erfinv($2))"},
            {"cast", "(($1)($2))"},
            {"pow", "std::pow(($2),($4))"},
            {"maximum", "std::max($1($2), $1($4))"},
            {"minimum", "std::min($1($2), $1($4))"},
            // NOTE(review): same @Tx / $1 inconsistency as the CUDA template.
            {"mod", "@if(@strcmp($1,float32)==0,(($2)-std::floor(($2)/($4))*($4)),@if(@strcmp(@Tx,float64)==0,(($2)-std::floor(($2)/($4))*($4)),(($2)%($4))))"},
            {"init_maximum", "std::numeric_limits<$1>::lowest()"},
            {"init_minimum", "std::numeric_limits<$1>::max()"},
        };

        // Backend-independent templates; these take precedence over the
        // backend-specific maps.
        static const unordered_map<string,string> both_map {
            {"add", "(($2)+($4))"},
            {"subtract", "(($2)-($4))"},
            {"multiply", "(($2)*($4))"},
            {"divide", "($1(($1($2))/($1($4))))"},
            {"floor_divide", "($1(($1($2))/($1($4))))"},
            {"less", "(($2)<($4))"},
            {"less_equal", "(($2)<=($4))"},
            {"greater", "(($2)>($4))"},
            {"greater_equal", "(($2)>=($4))"},
            {"equal", "(($2)==($4))"},
            {"not_equal", "(($2)!=($4))"},
            {"left_shift", "(($2)<<($4))"},
            {"right_shift", "(($2)>>($4))"},
            {"logical_and", "(($2)&&($4))"},
            {"logical_or", "(($2)||($4))"},
            {"logical_xor", "((bool($2))!=(bool($4)))"},
            {"bitwise_and", "(($2)&($4))"},
            {"bitwise_or", "(($2)|($4))"},
            {"bitwise_xor", "(($2)^($4))"},
            {"mean", "(($2)+$1($4)*($1(rcount)))"},
            {"init_add", "$1(0)"},
            {"init_multiply", "$1(1)"},
            {"init_logical_and", "true"},
            {"init_logical_or", "false"},
            {"init_logical_xor", "false"},
            {"init_bitwise_and", "$1(-1)"},
            {"init_bitwise_or", "$1(0)"},
            {"init_bitwise_xor", "$1(0)"},
            {"init_mean", "$1(0)"},
        };

        const string& op = args.at(0);
        // Look up without operator[] so that unknown op names never insert
        // empty entries into the shared static tables.
        auto hit = both_map.find(op);
        if (hit == both_map.end()) {
            const auto& backend_map = use_cuda ? cuda_map : cpu_map;
            hit = backend_map.find(op);
            if (hit == backend_map.end())
                return "";
        }
        return format(hit->second, args);
    }
};
// Register a heap-allocated CommonOpType during static initialization of this
// translation unit. The int `_` exists only so the call runs at that point;
// its value is unused in the visible code.
static int _ = registe_op_type(new CommonOpType());
}

View File

@ -33,4 +33,12 @@ OpInfo get_op_info(const string& name) {
return op_info_map.at(op_file_name);
}
// All registered OpByType handlers, in registration order.
// NOTE(review): registe_op_type() is invoked from static initializers in
// other translation units; that relies on this vector being constructed
// first — confirm the link/initialization order guarantees it.
vector<OpByType*> op_types;

// Append `op_type` to op_types. The pointer is stored as-is (no copy, no
// ownership handling here). Always returns 0, which lets callers trigger
// registration from a static variable initializer.
int registe_op_type(OpByType* op_type) {
    op_types.emplace_back(op_type);
    return 0;
}
} // jittor

View File

@ -32,4 +32,12 @@ void op_registe(const OpInfo& op_info);
bool has_op(const string& name);
OpInfo get_op_info(const string& name);
// Interface for a family of element-wise op expanders keyed by data type.
struct OpByType {
    // Type names this handler recognises.
    unordered_set<string> types;

    // Polymorphic base: a virtual destructor makes deletion through an
    // OpByType* well-defined for derived handlers.
    virtual ~OpByType() = default;

    // Expand an op invocation described by `args` into source text.
    // Callers treat an empty result as "this handler does not apply".
    virtual string expand_op(const vector<string>& args) = 0;
};
// List of registered OpByType handlers.
extern vector<OpByType*> op_types;
// Add a handler to op_types; returns an int so it can be used as a static
// variable initializer.
int registe_op_type(OpByType*);
} // jittor

View File

@ -8,7 +8,6 @@
#include <limits>
#include "var.h"
#include "ops/reduce_op.h"
#include "ops/binary_op_defs.h"
#include "ops/op_register.h"
#include "executor.h"
@ -364,14 +363,14 @@ void ReduceOp::jit_run() {
Ty rcount = Ty(y->num) / Ty(x->num);
@for(d, 0, DIM,@if(REDUCE>>d&1,, for (index_t xi@d=0; xi@d < xshape@d; xi@d++))) {
auto yid = 0 @for(d, 0, DIM,@if(REDUCE>>d&1,, + xi@d * ystride@d));
yp[yid] = @expand_macro(init_@OP, Ty);
yp[yid] = @expand_op(init_@OP, @Ty);
}
@for(d, 0, DIM,@if(REDUCE>>d&1,, for (index_t xi@d=0; xi@d < xshape@d; xi@d++))) {
@for(d, 0, DIM,@if(REDUCE>>d&1, for (index_t xi@d=0; xi@d < xshape@d; xi@d++),)) {
auto yid = 0 @for(d, 0, DIM,@if(REDUCE>>d&1,, + xi@d * ystride@d));
auto xid = 0 @for(d, 0, DIM, + xi@d * xstride@d);
yp[yid] = @expand_macro(@OP, Ty, yp[yid], xp[xid]);
yp[yid] = @expand_op(@OP, @Ty, yp[yid], @Ty, xp[xid], @Tx);
}
}
(void)count, (void)rcount, (void)yshape0, (void)ystride0;

View File

@ -8,7 +8,6 @@
#include <limits>
#include "var.h"
#include "ops/reindex_reduce_op.h"
#include "ops/binary_op_defs.h"
#include "ops/op_register.h"
namespace jittor {
@ -112,7 +111,7 @@ void ReindexReduceOp::jit_run() {
@for(d, 0, XDIM, for (index_t i@d=0; i@d < xshape@d; i@d++)) {
auto xid = @for(d, 0, XDIM, + i@d * xstride@d);
xp[xid] = @expand_macro(init_@OP, Tx);
xp[xid] = @expand_op(init_@OP, @Tx);
}
// generate d-for loop
@for(d, 0, YDIM, for (index_t i@d=0; i@d < yshape@d; i@d++)) {
@ -121,7 +120,7 @@ void ReindexReduceOp::jit_run() {
auto xid = @for(d, 0, XDIM, + xid@d * xstride@d);
bool check_overflow = 0 @for(d, 0, XDIM, || xid@d<0 || xid@d>=xshape@d) @for(d, 0, OSIZE, || (@expand_macro(OFD@d)));
if (!check_overflow)
xp[xid] = @expand_macro(@OP, Tx, xp[xid], yp[yid]);
xp[xid] = @expand_op(@OP, @Tx, xp[xid], @Tx, yp[yid], @Tx);
}
}
#endif // JIT

View File

@ -9,7 +9,6 @@
#include "ops/setitem_op.h"
#include "ops/getitem_op.h"
#ifdef JIT
#include "ops/binary_op_defs.h"
#ifdef JIT_cuda
#include <cuda_runtime.h>
#include "helper_cuda.h"
@ -340,12 +339,12 @@ void SetitemOp::jit_run() {
@if(@is_def(JIT_cpu),
@if(@strcmp(@OP,void)==0,
op[iid] = (Ti)dp[did],
op[iid] = @expand_macro(@OP, Ti, op[iid], dp[did])
op[iid] = @expand_op(@OP, @Ti, op[iid], @Ti, dp[did], @Td)
);
,
@if(@strcmp(@OP,void)==0, op[iid] = (Ti)dp[did],
@if(@strcmp(@OP,add)==0, atomicAdd(&op[iid], (Ti)dp[did]),
op[iid] = @expand_macro(@OP, Ti, op[iid], dp[did])
op[iid] = @expand_op(@OP, @Ti, op[iid], @Ti, dp[did], @Td)
)
);
)

View File

@ -8,7 +8,6 @@
#include "misc/cpu_math.h"
#include "var.h"
#include "ops/unary_op.h"
#include "ops/unary_op_defs.h"
#include "ops/op_register.h"
namespace jittor {
@ -688,7 +687,7 @@ void UnaryOp::jit_run() {
auto* __restrict__ yp = y->ptr<Ty>();
index_t num = y->num;
for (index_t i=0; i<num; i++)
yp[i] = @expand_macro(@OP, Ty, xp[i]);
yp[i] = @expand_op(@OP, @Ty, xp[i], @Tx);
}
#endif // JIT

View File

@ -1,84 +0,0 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "common.h"
namespace jittor {
#define logical_not(T,x) (!(x))
#define bitwise_not(T,x) (~(x))
#define negative(T,x) (-(x))
#ifdef JIT_cuda
// TODO: add float64 version
#define abs(T,x) ::abs(x)
#define log(T,x) ::logf((T)(x))
#define exp(T,x) ::expf((T)(x))
#define sqrt(T,x) ::sqrtf((T)(x))
#define round(T,x) ((T) ::roundf((x)))
#define floor(T,x) ((T) ::floorf((x)))
#define ceil(T,x) ((T) ::ceilf((x)))
#define round_int(T,x) ((T) ::roundf((x)))
#define floor_int(T,x) ((T) ::floorf((x)))
#define ceil_int(T,x) ((T) ::ceilf((x)))
#define sin(T,x) ((T) ::sinf((x)))
#define asin(T,x) ((T) ::asinf((x)))
#define sinh(T,x) ((T) ::sinhf((x)))
#define asinh(T,x) ((T) ::asinhf((x)))
#define cos(T,x) ((T) ::cosf((x)))
#define acos(T,x) ((T) ::acosf((x)))
#define cosh(T,x) ((T) ::coshf((x)))
#define acosh(T,x) ((T) ::acoshf((x)))
#define tan(T,x) ((T) ::tanf((x)))
#define atan(T,x) ((T) ::atanf((x)))
#define tanh(T,x) ((T) ::tanhf((x)))
#define atanh(T,x) ((T) ::atanhf((x)))
#define sigmoid(T,x) ((T) (1.0f/(1.0f+::expf((::min(T(-(x)), T(@if(@strcmp(@T,float32)==0,30,300))))))))
#define erf(T,x) ((T) ::erff((x)))
#define erfinv(T,x) ((T) ::erfinvf((T)(x)))
#else
#define abs(T,x) std::abs(x)
#define log(T,x) std::log((T)(x))
#define exp(T,x) std::exp((T)(x))
#define sqrt(T,x) std::sqrt((T)(x))
#define round(T,x) ((T)std::round((x)))
#define floor(T,x) ((T)std::floor((x)))
#define ceil(T,x) ((T)std::ceil((x)))
#define round_int(T,x) ((T)std::round((x)))
#define floor_int(T,x) ((T)std::floor((x)))
#define ceil_int(T,x) ((T)std::ceil((x)))
#define sin(T,x) ((T) std::sin((x)))
#define asin(T,x) ((T) std::asin((x)))
#define sinh(T,x) ((T) std::sinh((x)))
#define asinh(T,x) ((T) std::asinh((x)))
#define cos(T,x) ((T) std::cos((x)))
#define acos(T,x) ((T) std::acos((x)))
#define cosh(T,x) ((T) std::cosh((x)))
#define acosh(T,x) ((T) std::acosh((x)))
#define tan(T,x) ((T) std::tan((x)))
#define atan(T,x) ((T) std::atan((x)))
#define tanh(T,x) ((T) std::tanh((x)))
#define atanh(T,x) ((T) std::atanh((x)))
#define sigmoid(T,x) ((T) (1.0f/(1.0f+std::exp(std::min(T(-(x)), T(@if(@strcmp(@T,float32)==0,30,300)))))))
#define erf(T,x) ((T) std::erf((x)))
#define erfinv(T,x) (jittor::_erfinv(x))
#endif
#define cast(T,x) ((T)(x))
} // jittor

View File

@ -47,4 +47,18 @@ string strip(const string& s) {
return s.substr(i,j-i);
}
// Substitute positional placeholders in a template string.
//
// Each "$<d>" in `s`, where <d> is a single decimal digit, is replaced by
// v.at(d). A '$' that is not followed by a digit (including a trailing '$')
// is copied to the output unchanged.
//
// Throws std::out_of_range (from vector::at) when a placeholder digit has no
// corresponding element in `v`.
string format(const string& s, const vector<string>& v) {
    string ss;
    ss.reserve(s.size());
    for (size_t i=0; i<s.size(); i++) {
        const bool is_placeholder = s[i] == '$' && i+1 < s.size()
            && s[i+1] >= '0' && s[i+1] <= '9';
        if (!is_placeholder) {
            ss += s[i];
            continue;
        }
        ss += v.at(s[i+1] - '0');
        i++; // also consume the digit
    }
    return ss;
}
} // jittor

View File

@ -27,4 +27,6 @@ vector<string> split(const string& s, const string& sep, int max_split=0);
string strip(const string& s);
string format(const string& s, const vector<string>& v);
} // jittor

View File

@ -111,17 +111,6 @@ class TestOpCompiler(unittest.TestCase):
check("@{a^b == 7}", "2")
check("@{(a^b) == 7}", "1")
check("@{b<<a == 5*4}", "1")
check('''#include "ops/binary_op_defs.h"
#define OP1(a, b) a+b
OP1
@expand_macro(OP1,1,2)
@expand_macro(maximum, T, 1, 2)
@expand_macro(@OP,T,1,2)''',
''' #define OP1(a, b) a+b
OP1
1+2
std::max(T(1), T(2))
((1)+T(2)*(T(rcount)))''')
expect_error(lambda: jit_precompile(vars, "@{a"))
expect_error(lambda: jit_precompile(vars, "@for(a"))
expect_error(lambda: jit_precompile(vars, "@for(i,l,r)"))