mirror of https://github.com/Jittor/Jittor
optimize concat and split
This commit is contained in:
parent 54fb38caed
commit 9c74699707

@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************

-__version__ = '1.3.2.1'
+__version__ = '1.3.2.2'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int

@@ -231,6 +231,7 @@ def _merge_dtypes(dtypes):
         dtype = names[s]+("" if e ==-1 else dbytes[e])
     return dtype

+@jt.flag_scope(amp_reg=4) # _custom_flag
 def concat(arr, dim=0):
     '''Concat Operator can concat a list of jt Var at a specfic dimension.

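The new decorator runs `concat` under a modified `amp_reg` (the comment ties it to the `_custom_flag` node flag added later in this commit), while the function itself remains the public `jt.concat` entry point that the rest of the commit switches callers to. A minimal usage sketch, using only the `jt.concat` signature shown above:

```python
import jittor as jt

a = jt.array([[1, 2], [3, 4]])
b = jt.array([[5, 6]])
c = jt.concat([a, b], dim=0)    # same semantics as the old jt.contrib.concat
print(c.shape)                  # [3,2]
```
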
@@ -34,7 +34,7 @@ class Generator(nn.Module):
             nn.Tanh())

     def execute(self, noise, labels):
-        gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)
+        gen_input = jt.concat((self.label_emb(labels), noise), dim=1)
         img = self.model(gen_input)
         img = img.view((img.shape[0], *img_shape))
         return img

@@ -55,7 +55,7 @@ class Discriminator(nn.Module):
             nn.Linear(512, 1))

     def execute(self, img, labels):
-        d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
+        d_in = jt.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
         validity = self.model(d_in)
         return validity

@@ -121,8 +121,12 @@ void CublasBatchedMatmulOp::jit_run() {
     #if CUDART_VERSION >= 11000
     cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
     cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
-    if (use_tensorcore) {
+    if (use_tensorcore>=3) {
+        computeType = CUBLAS_COMPUTE_32F_FAST_16F;
+    } else if (use_tensorcore==2) {
+        computeType = CUBLAS_COMPUTE_32F_FAST_16BF;
+    } else if (use_tensorcore==1) {
         computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
     }
     if (a->dtype() == ns_float16
         || b->dtype() == ns_float16 || c->dtype() == ns_float16) {

@@ -78,8 +78,12 @@ void CublasMatmulOp::jit_run() {
     #if CUDART_VERSION >= 11000
     cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
     cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
-    if (use_tensorcore) {
+    if (use_tensorcore>=3) {
+        computeType = CUBLAS_COMPUTE_32F_FAST_16F;
+    } else if (use_tensorcore==2) {
+        computeType = CUBLAS_COMPUTE_32F_FAST_16BF;
+    } else if (use_tensorcore==1) {
         computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
     }
     if (a->dtype() == ns_float16
         || b->dtype() == ns_float16 || c->dtype() == ns_float16) {

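The two cuBLAS hunks map the integer `use_tensorcore` flag onto a `cublasComputeType_t`: 1 selects TF32, 2 selects BF16-accelerated FP32, 3 or higher selects FP16-accelerated FP32, and 0 keeps plain CUBLAS_COMPUTE_32F. A hedged sketch of driving this from Python (assumes `jt.flags.use_tensorcore` accepts the new levels shown in the kernel; the actual precision/speed trade-off depends on the GPU):

```python
import jittor as jt

jt.flags.use_cuda = 1
# 0: plain FP32, 1: TF32, 2: FP32 via BF16 tensor cores, 3: FP32 via FP16 tensor cores
jt.flags.use_tensorcore = 1        # assumption: levels follow the mapping above

a = jt.rand(256, 256)
b = jt.rand(256, 256)
c = jt.matmul(a, b)                # dispatched through the cublas ops patched here
c.sync()
```
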
@@ -257,7 +257,7 @@ def stack(x, dim=0):
         return x[0].unsqueeze(dim)

     res = [x_.unsqueeze(dim) for x_ in x]
-    return jt.contrib.concat(res, dim=dim)
+    return jt.concat(res, dim=dim)
 jt.Var.stack = stack

 def flip(x, dim=0):

@@ -342,7 +342,7 @@ def cross(input, other, dim=-1):
     a1 = input[(slice(None,),)*dim+(1,)]*other[(slice(None,),)*dim+(2,)]-input[(slice(None,),)*dim+(2,)]*other[(slice(None,),)*dim+(1,)]
     a2 = input[(slice(None,),)*dim+(2,)]*other[(slice(None,),)*dim+(0,)]-input[(slice(None,),)*dim+(0,)]*other[(slice(None,),)*dim+(2,)]
     a3 = input[(slice(None,),)*dim+(0,)]*other[(slice(None,),)*dim+(1,)]-input[(slice(None,),)*dim+(1,)]*other[(slice(None,),)*dim+(0,)]
-    return jt.contrib.concat([a1.unsqueeze(dim),a2.unsqueeze(dim),a3.unsqueeze(dim)], dim=dim)
+    return jt.concat([a1.unsqueeze(dim),a2.unsqueeze(dim),a3.unsqueeze(dim)], dim=dim)
 jt.Var.cross = cross

 def normalize(input, p=2, dim=1, eps=1e-30):

@@ -465,7 +465,7 @@ def unique(x):
     _,x = jt.argsort(x)
     index,= jt.index((x.shape[0],))
     y = x[1:][x[index[1:]] != x[index[:-1]]]
-    x = jt.contrib.concat([x[:1],y],dim=0)
+    x = jt.concat([x[:1],y],dim=0)
     return x

 jt.Var.unique = unique

@@ -501,7 +501,7 @@ def nonzero(x):
     x = [xx.unsqueeze(1) for xx in x]
     if len(x)<2:
         return x[0]
-    x = jt.contrib.concat(x,dim=1)
+    x = jt.concat(x,dim=1)
     return x

 jt.Var.nonzero = nonzero

@@ -546,7 +546,7 @@ def meshgrid(*tensors):
     return grids


-def split(d,split_size,dim):
+def split(d, split_size, dim=0):
     r'''
     Splits the tensor into chunks. Each chunk is a view of the original tensor.

@@ -575,7 +575,9 @@ def split(d,split_size,dim):

     ans = []
     last = 0
-    for i in split_size:
+    s_last = len(split_size)-1
+    gopt_disable = jt.flags.gopt_disable
+    for j, i in enumerate(split_size):
         if i==0:
             shape = list(d.shape)
             shape[dim]=0

@@ -584,7 +586,11 @@ def split(d,split_size,dim):
             continue

         ss = (slice(None),)*dim+(slice(last,last+i),)
-        new_d = d[ss]
+        if gopt_disable:
+            new_d = d.getitem(ss)
+        else:
+            new_d, d = d.getitem(ss, int(j==s_last))
+
         last +=i
         ans.append(new_d)
     return tuple(ans)

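With the loop above, every chunk except the last goes through the new two-output `getitem`, so the remaining data is threaded through `d` instead of re-slicing the original tensor each time; setting `jt.flags.gopt_disable` falls back to the old single-output path. A short usage sketch of the public API (only the `dim=0` default comes from this hunk; passing a list of sizes is assumed to keep working as before):

```python
import jittor as jt

x = jt.rand(10, 4)
parts = x.split(3)                     # dim now defaults to 0; chunk sizes 3,3,3,1
print([p.shape for p in parts])

front, back = x.split([6, 4], dim=0)   # assumed: explicit size lists behave as before
```
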
@@ -76,7 +76,7 @@ class _DenseLayer(nn.Sequential):
         new_features = super(_DenseLayer, self).execute(x)
         if (self.drop_rate > 0):
             new_features = self.drop(new_features)
-        return jt.contrib.concat([x, new_features], dim=1)
+        return jt.concat([x, new_features], dim=1)

 class _DenseBlock(nn.Sequential):

@@ -121,7 +121,7 @@ class Inception(nn.Module):

     def execute(self, x):
         outputs = self._forward(x)
-        return jt.contrib.concat(outputs, dim=1)
+        return jt.concat(outputs, dim=1)

 class InceptionAux(nn.Module):

@@ -118,7 +118,7 @@ class InceptionA(nn.Module):

     def execute(self, x):
         outputs = self._forward(x)
-        return jt.contrib.concat(outputs, dim=1)
+        return jt.concat(outputs, dim=1)

 class InceptionB(nn.Module):

@@ -142,7 +142,7 @@ class InceptionB(nn.Module):

     def execute(self, x):
         outputs = self._forward(x)
-        return jt.contrib.concat(outputs, dim=1)
+        return jt.concat(outputs, dim=1)

 class InceptionC(nn.Module):

@@ -179,7 +179,7 @@ class InceptionC(nn.Module):

     def execute(self, x):
         outputs = self._forward(x)
-        return jt.contrib.concat(outputs, dim=1)
+        return jt.concat(outputs, dim=1)

 class InceptionD(nn.Module):

@@ -207,7 +207,7 @@ class InceptionD(nn.Module):

     def execute(self, x):
         outputs = self._forward(x)
-        return jt.contrib.concat(outputs, dim=1)
+        return jt.concat(outputs, dim=1)

 class InceptionE(nn.Module):

@@ -229,11 +229,11 @@ class InceptionE(nn.Module):
         branch1x1 = self.branch1x1(x)
         branch3x3 = self.branch3x3_1(x)
         branch3x3 = [self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3)]
-        branch3x3 = jt.contrib.concat(branch3x3, dim=1)
+        branch3x3 = jt.concat(branch3x3, dim=1)
         branch3x3dbl = self.branch3x3dbl_1(x)
         branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
         branch3x3dbl = [self.branch3x3dbl_3a(branch3x3dbl), self.branch3x3dbl_3b(branch3x3dbl)]
-        branch3x3dbl = jt.contrib.concat(branch3x3dbl, dim=1)
+        branch3x3dbl = jt.concat(branch3x3dbl, dim=1)
         branch_pool = nn.pool(x, kernel_size=3, op="mean", stride=1, padding=1)
         branch_pool = self.branch_pool(branch_pool)
         outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]

@@ -241,7 +241,7 @@ class InceptionE(nn.Module):

     def execute(self, x):
         outputs = self._forward(x)
-        return jt.contrib.concat(outputs, dim=1)
+        return jt.concat(outputs, dim=1)

 class InceptionAux(nn.Module):

@@ -45,9 +45,9 @@ class InvertedResidual(nn.Module):
         if (self.stride == 1):
             x1 = x[:,0:x.shape[1]//2]
             x2 = x[:,x.shape[1]//2:x.shape[1]]
-            out = jt.contrib.concat([x1, self.branch2(x2)], dim=1)
+            out = jt.concat([x1, self.branch2(x2)], dim=1)
         else:
-            out = jt.contrib.concat([self.branch1(x), self.branch2(x)], dim=1)
+            out = jt.concat([self.branch1(x), self.branch2(x)], dim=1)
         out = channel_shuffle(out, 2)
         return out

@@ -26,7 +26,7 @@ class Fire(nn.Module):

     def execute(self, x):
         x = self.squeeze_activation(self.squeeze(x))
-        return jt.contrib.concat([self.expand1x1_activation(self.expand1x1(x)), self.expand3x3_activation(self.expand3x3(x))], dim=1)
+        return jt.concat([self.expand1x1_activation(self.expand1x1(x)), self.expand3x3_activation(self.expand3x3(x))], dim=1)

 class SqueezeNet(nn.Module):

@@ -80,7 +80,7 @@ class Generator(nn.Module):
             nn.Tanh())

     def execute(self, noise, labels):
-        gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)
+        gen_input = jt.concat((self.label_emb(labels), noise), dim=1)
         img = self.model(gen_input)
         img = img.view((img.shape[0], *img_shape))
         return img

@@ -105,7 +105,7 @@ class Discriminator(nn.Module):
             nn.Linear(512, 1))

     def execute(self, img, labels):
-        d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
+        d_in = jt.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
         validity = self.model(d_in)
         return validity
 ```

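All of the model and README edits above are the same mechanical swap from `jt.contrib.concat` to `jt.concat`; the branch-merging pattern they share looks like this condensed sketch (a hypothetical two-branch module, not one of the models in this commit):

```python
import jittor as jt
from jittor import nn

class TwoBranch(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.b1 = nn.Conv(ch, ch, 1)
        self.b2 = nn.Conv(ch, ch, 3, padding=1)

    def execute(self, x):
        # merge the branch outputs along the channel dimension
        return jt.concat([self.b1(x), self.b2(x)], dim=1)

y = TwoBranch(8)(jt.rand(2, 8, 16, 16))
print(y.shape)   # [2,16,16,16]
```
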
@@ -31,6 +31,8 @@ pytype_map = {
     "int": ["PyLong_AsLong", "PyLong_FromLong", "PyLong_CheckExact"],
     "int64": ["PyLong_AsLongLong", "PyLong_FromLongLong", "PyLong_CheckExact"],
     "uint": ["PyLong_AsUnsignedLong", "PyLong_FromUnsignedLong", "PyLong_CheckExact"],
+    "uint8": ["PyLong_AsUnsignedLong", "PyLong_FromUnsignedLong", "PyLong_CheckExact"],
+    "uint16": ["PyLong_AsUnsignedLong", "PyLong_FromUnsignedLong", "PyLong_CheckExact"],
     "uint64": ["PyLong_AsUnsignedLongLong", "PyLong_FromUnsignedLongLong", "PyLong_CheckExact"],
     "void": ["...", "GET_PY_NONE", "..."],
     "PyObject*": ["","",""],

@@ -29,7 +29,7 @@ class SparseVar:
     def t(self):
         indices = list(self.indices.split(1,dim=0))
         indices[-1],indices[-2] = indices[-2],indices[-1]
-        indices = jt.contrib.concat(indices,dim=0)
+        indices = jt.concat(indices,dim=0)
         shape = list(self.shape)
         shape[-1],shape[-2] = shape[-2],shape[-1]
         shape = jt.NanoVector(shape)

@@ -132,4 +132,46 @@ void clean_graph() {
     }
 }

+void check_circle(Node* s) {
+    vector<Node*> q = {s};
+    vector<int> fa = {-1};
+    unordered_set<Node*> visited = {s};
+    for (int i=0; i<q.size(); i++) {
+        auto n = q[i];
+        for (auto o : n->outputs()) {
+            if (o == s) {
+                LOGe << "Found circle:";
+                int j=i;
+                vector<Node*> nodes{o};
+                while (j) {
+                    nodes.push_back(q[j]);
+                    j = fa[j];
+                }
+                for (int i=0; i<nodes.size(); i++) {
+                    auto n = nodes[i];
+                    auto out = nodes[(i-1+nodes.size())%nodes.size()];
+                    auto in = nodes[(i+1)%nodes.size()];
+                    int in_id=0, out_id=0;
+                    for (auto ii : n->inputs()) {
+                        if (ii == in) break;
+                        in_id ++;
+                    }
+                    for (auto oo : n->outputs()) {
+                        if (oo == out) break;
+                        out_id ++;
+                    }
+                    LOGe << n << "in:" >> in_id >> '/' >> n->inputs().size() << "out:" >> out_id >> '/' >> n->outputs().size();
+                }
+                LOGf << "found circle";
+            }
+            if (!visited.count(o)) {
+                visited.emplace(o);
+                q.push_back(o);
+                fa.push_back(i);
+            }
+        }
+    }
+}
+
 } // jittor

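`check_circle` walks the graph breadth-first from `s`, remembering the parent index of every visited node; if it reaches `s` again it replays the parent chain and logs each node on the cycle together with the input/output slots that link it to its neighbours. A rough Python sketch of the same idea (hypothetical `outputs()` accessor, not part of the Jittor Python API):

```python
def check_circle(start, outputs):
    """outputs(node) -> iterable of successor nodes (hypothetical accessor)."""
    q = [start]
    fa = [-1]                    # parent index into q, mirrors the C++ vectors
    visited = {id(start)}
    i = 0
    while i < len(q):
        n = q[i]
        for o in outputs(n):
            if o is start:
                # replay parent pointers to reconstruct the cycle
                cycle, j = [o], i
                while j > 0:
                    cycle.append(q[j])
                    j = fa[j]
                raise RuntimeError(f"found circle of length {len(cycle)}")
            if id(o) not in visited:
                visited.add(id(o))
                q.append(o)
                fa.append(i)
        i += 1
```
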
@@ -148,4 +148,6 @@ void toplogical_sort_backward(vector<Node*>& nodes, vector<Node*>& sorted, Func&
     ASSERTop(nodes.size(),==,sorted.size());
 }

+void check_circle(Node* s);
+
 } // jittor

@@ -9,17 +9,6 @@

 namespace jittor {

-DEFINE_FLAG(int, amp_reg, 0, "Auto mixed-precision control registers, bit 0: prefer 32; bit 1: prefer 16; bit 2: keep reduce type; bit 3 keep white list type; bit 4: array like op prefer too");
-
-DEFINE_FLAG_WITH_SETTER(int, auto_mixed_precision_level, 0, "Auto mixed-precision optimization level, 0: not use fp16, 1-3: preserve level, not use fp16 for now; 4: perfer fp16, but some ops use fp32 e.g. sum,exp; 5: simular with 4, and array op will automatically convert to fp16; 6: all ops prefer fp16");
-
-void setter_auto_mixed_precision_level(int value) {
-    if (value <= 3) amp_reg = 0; else
-    if (value == 4) amp_reg = amp_prefer16; else
-    if (value == 5) amp_reg = amp_prefer16 | amp_array_prefer; else
-    if (value == 6) amp_reg = amp_prefer16 | amp_array_prefer | amp_keep_reduce | amp_keep_white;
-}
-
 #define FOR_ALL_TYPES(m) \
     m(bool) \
     m(int8) \

@@ -59,6 +59,7 @@ struct NodeFlags {
         _prefer_16=_n+10,
         // bit11: reduce keep type unchange
         _reduce_keep=_n+11,
+        _custom_flag=_reduce_keep,
     };

     inline void set(Flags f, int a=1, int nbits=1) {

@@ -6,6 +6,7 @@
 // ***************************************************************
 #include <cmath>
 #include "var.h"
+#include "executor.h"
 #include "ops/getitem_op.h"
 #include "ops/op_register.h"
 #ifdef JIT_cuda

@@ -27,6 +28,8 @@ namespace jittor {

 static auto make_number = get_op_info("number")
     .get_constructor<VarPtr, float, Var*>();
+static auto make_empty = get_op_info("empty")
+    .get_constructor<VarPtr, NanoVector, NanoString>();
 static auto make_setitem = get_op_info("setitem")
     .get_constructor<VarPtr, Var*, VarSlices&&, Var*, NanoString>();

@@ -38,6 +41,19 @@ GetitemOp::GetitemOp(Var* x, VarSlices&& slices)
     create_output(nullptr, x->dtype());
 }

+GetitemOp::GetitemOp(Var* x, VarSlices&& slices, int _)
+    : vs(move(slices)) {
+    flags.set(NodeFlags::_cpu);
+    flags.set(NodeFlags::_cuda);
+    flags.set(NodeFlags::_has_gopt);
+    flags.set(NodeFlags::_custom_flag);
+    flags.set(NodeFlags::_grads);
+    create_output(nullptr, x->dtype());
+    auto out2 = create_output(nullptr, x->dtype());
+    out2->share_with(x);
+    ns.data = _;
+}
+
 void GetitemOp::infer_slices(
     StackVector<>& __restrict__ i_to_vs,
     StackVector<>& __restrict__ i_to_o,

@@ -377,6 +393,10 @@ void GetitemOp::infer_shape() {
     this->i_to_vs = i_to_vs.to_nano_vector();
     this->i_to_o = i_to_o.to_nano_vector();
     this->o_shape = o_shape.to_nano_vector();
+    if (outputs().size() > 1) {
+        auto out2 = output(1);
+        out2->set_shape(in->shape);
+    }

     LOGV(999) << "\ni_to_vs:" << i_to_vs
         << "\ni_to_o:" << i_to_o

@@ -396,6 +416,24 @@ VarPtr GetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
     return make_setitem(zeros, VarSlices(vs, true), dout, ns_void);
 }

+void GetitemOp::grads(Var** dout, VarPtr* dins) {
+    VarPtr x = dout[1];
+    VarPtr y = dout[0];
+    if (!x) {
+        auto in = inputs().front();
+        if (in->num<0) exe.run_sync(vector<Var*>({in}), true);
+        // ns.data represents this is the last split var
+        if (ns.data)
+            x = make_empty(in->shape, in->dtype());
+        else
+            x = make_number(0, in);
+    }
+    if (!y) {
+        y = make_number(0, outputs().front());
+    }
+    dins[0] = make_setitem(x, VarSlices(vs, true), y, ns_void);
+}
+
 void GetitemOp::jit_prepare(JK& jk) {
     auto in = inputs().front();
     int idim = i_to_vs.size();

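`grads` receives the downstream gradients of both outputs at once: `dout[0]` for the sliced result and `dout[1]` for the pass-through output shared with the input. Missing gradients are replaced with zeros, or with an uninitialized buffer when `ns.data` marks the last chunk of a split, since that buffer is then fully overwritten by the chained setitem ops. A hedged end-to-end sketch from the Python side, relying only on `jt.split`, `jt.grad`, and `jt.rand` used elsewhere in this commit:

```python
import jittor as jt

x = jt.rand(9)
a, b, c = x.split(3)              # three of the two-output getitems under the hood
loss = (a * 2).sum() + c.sum()    # b never contributes to the loss
g = jt.grad(loss, x)
print(g.numpy())                  # expected: [2]*3 + [0]*3 + [1]*3
```
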
@@ -22,9 +22,12 @@ struct GetitemOp : Op {
     int first_oid_of_var, var_dim;

     GetitemOp(Var* x, VarSlices&& slices);
+    // @attrs(multiple_outputs)
+    GetitemOp(Var* x, VarSlices&& slices, int _);

     const char* name() const override { return "getitem"; }
     VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
+    void grads(Var** dout, VarPtr* dins) override;
     void infer_shape() override;
     void compile_optimize(string& src) override;
     void graph_optimize() override;

@ -28,6 +28,8 @@ static auto make_array = get_op_info("array")
|
|||
.get_constructor<VarPtr, const void*, NanoVector, NanoString>();
|
||||
static auto make_getitem = get_op_info("getitem")
|
||||
.get_constructor<VarPtr, Var*, VarSlices&&>();
|
||||
static auto make_getitem2 = get_op_info("getitem")
|
||||
.get_constructor<vector<VarPtr>, Var*, VarSlices&&, int>();
|
||||
static auto make_setitem = get_op_info("setitem")
|
||||
.get_constructor<VarPtr, Var*, VarSlices&&, Var*, NanoString>();
|
||||
static auto make_binary = get_op_info("binary")
|
||||
|
@@ -40,8 +42,11 @@ SetitemOp::SetitemOp(Var* x, VarSlices&& slices, Var* y, NanoString op)
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
     flags.set(NodeFlags::_has_gopt);
-    ASSERT(ns == ns_void || ns.is_binary());
+    ASSERT(op == ns_void || op.is_binary());
     create_output(nullptr, x->dtype());
+    if (flags.get(NodeFlags::_custom_flag)) {
+        flags.set(NodeFlags::_grads);
+    }
 }

 void SetitemOp::infer_shape() {

@@ -120,6 +125,12 @@ void SetitemOp::infer_shape() {
         << "\no_shape:" << o_shape;
 }

+void SetitemOp::grads(Var** dout, VarPtr* dins) {
+    auto outs = make_getitem2(dout[0], VarSlices(vs, true), 0);
+    dins[0] = move(outs[1]);
+    dins[1] = move(outs[0]);
+}
+
 VarPtr SetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
     if (v_index >= 2)
         return nullptr;

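The setitem backward now reuses the two-output getitem: one output is the gradient slice routed to `y`, the other carries the gradient routed back to `x` (which should be zero inside the assigned region). A hedged Python-level sketch of the behaviour this supports (assumes ordinary `Var.__setitem__` and `jt.grad`; `z = x * 1.0` just keeps `x` itself untouched):

```python
import jittor as jt

x = jt.rand(6)
y = jt.rand(2)
z = x * 1.0                       # separate node, so the setitem does not rewrite x
z[2:4] = y                        # setitem op
gx, gy = jt.grad(z.sum(), [x, y])
print(gx.numpy())                 # expected: ones, except zeros at positions 2..3
print(gy.numpy())                 # expected: ones
```
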
@@ -25,6 +25,7 @@ struct SetitemOp : Op {

     const char* name() const override { return "setitem"; }
     VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
+    void grads(Var** dout, VarPtr* dins) override;
     void infer_shape() override;
     void compile_optimize(string& src) override;
     void graph_optimize() override;

@@ -73,6 +73,9 @@ static void setitem_inplace(SetitemOp* op) {
         return;
     if (data->allocator)
         return;
+    auto data_op = data->input();
+    if (data_op->flags.get(NodeFlags::_custom_flag))
+        return;

     auto in_shape = input->shape;
     int64 inplace_size = 1;

@@ -22,6 +22,16 @@ DEFINE_FLAG(bool, no_grad, 0,
     "No grad for all jittor Var creation");
 DEFINE_FLAG(bool, no_fuse, 0,
     "No fusion optimization for all jittor Var creation");
+DEFINE_FLAG(int, amp_reg, 0, "Auto mixed-precision control registers, bit 0: prefer 32; bit 1: prefer 16; bit 2: keep reduce type; bit 3 keep white list type; bit 4: array like op prefer too");
+
+DEFINE_FLAG_WITH_SETTER(int, auto_mixed_precision_level, 0, "Auto mixed-precision optimization level, 0: not use fp16, 1-3: preserve level, not use fp16 for now; 4: perfer fp16, but some ops use fp32 e.g. sum,exp; 5: simular with 4, and array op will automatically convert to fp16; 6: all ops prefer fp16");
+
+void setter_auto_mixed_precision_level(int value) {
+    if (value <= 3) amp_reg = 0; else
+    if (value == 4) amp_reg = amp_prefer16; else
+    if (value == 5) amp_reg = amp_prefer16 | amp_array_prefer; else
+    if (value == 6) amp_reg = amp_prefer16 | amp_array_prefer | amp_keep_reduce | amp_keep_white;
+}

 Var::Var(NanoVector shape, NanoString dtype)
     : shape(shape),

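Moving `amp_reg` and `auto_mixed_precision_level` into var.cc decouples them from the fp16 op file; the user-facing knob is unchanged, and the setter maps levels 4-6 onto the amp_reg bits exactly as before. A hedged sketch of using it from Python (flag names as defined above; whether a given op actually runs in float16 depends on the level and on the op's white list):

```python
import jittor as jt

# 0 disables fp16, 4 prefers fp16 except reduce/exp-like ops, 6 prefers fp16 everywhere
jt.flags.auto_mixed_precision_level = 4

a = jt.rand(32, 32)
b = jt.rand(32, 32)
c = jt.matmul(a, b)
print(c.dtype)                                # float16 or float32 depending on the level
jt.flags.auto_mixed_precision_level = 0       # restore the default
```
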
@@ -48,7 +48,7 @@ class TestConcatOp(unittest.TestCase):
     def test_concat_op(self):
         def check(tmp, dim=0):
             res1 = numpy_concat(tmp, dim=dim)
-            res2 = jt.contrib.concat(tmp, dim=dim)
+            res2 = jt.concat(tmp, dim=dim)
             assert (res2!=res1).data.sum()==0, "concat fail..."
         check([jt.array([[1],[2]]), jt.array([[2],[2]])])
         check([jt.array(np.array(range(24))).reshape((1,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])

@@ -68,7 +68,7 @@ class TestConcatOp(unittest.TestCase):
             arr = []
             for i in range(m):
                 arr.append(a[(slice(None),)*dim + (slice(i*size,i*size+size),)])
-            b = jt.contrib.concat(arr, dim)
+            b = jt.concat(arr, dim)
             if backward:
                 loss = b * a
                 b = jt.grad(loss, a)

@@ -20,7 +20,7 @@ class TestContrib(unittest.TestCase):
                 arr1.append(a)
                 arr2.append(jt.array(a))
             x = np.concatenate(tuple(arr1), dim)
-            y = jt.contrib.concat(arr2, dim)
+            y = jt.concat(arr2, dim)
             assert (x==y.data).all(), (x, y.data, arr1, arr2)
         check([2,3,4], 0, 2)
         check([2,3,4], 1, 3)

@@ -104,7 +104,7 @@ class TestSingleArray(unittest.TestCase):
                 arr1.append(a)
                 arr2.append(jt.array(a))
             x = np.concatenate(tuple(arr1), dim)
-            y = jt.contrib.concat(arr2, dim)
+            y = jt.concat(arr2, dim)
             assert (x==y.data).all()
         check([1], 0, 20)

@@ -39,7 +39,7 @@ class TestSetitem(unittest.TestCase):

         arr21 = jt.ones((2,2))
         arr22 = jt.ones((2,2)) * 2
-        arr2 = jt.contrib.concat([arr21, arr22], dim=0)
+        arr2 = jt.concat([arr21, arr22], dim=0)
         arr2.sync()
         arr21.data[0,0] = 3
         arr22.data[0,0] = 4

@@ -283,14 +283,17 @@ class TestSetitem(unittest.TestCase):
                 np.concatenate([b.data,c.data]))

     def test_concat_random(self):
-        def check():
+        def check(backward=False):
             n1, n2, n3 = 1000, 20, 10
             # n1, n2, n3 = 2, 2, 1
             # n1, n2, n3 = 3, 2, 3
             import random
             data = []
+            back = []
             for i in range(n1):
                 if len(data) > n2:
-                    del data[random.randint(0,len(data)-1)]
+                    v = random.randint(0,len(data)-1)
+                    # print("del", v)
+                    del data[v]
                 x1 = random.randint(0,9)
                 # print(i, x1)
                 if len(data) == 0:

@@ -319,7 +322,15 @@ class TestSetitem(unittest.TestCase):
                 elif x1 == 4:
                     # a = jt.random((random.randint(10,20),))
                     a = jt.array(np.random.rand(random.randint(n3,n3*2)))
+                    if backward and random.randint(0,1):
+                        back.append(a)
                     data.append(a)
+                elif x1 == 5:
+                    v = random.randint(0,len(data)-1)
+                    a = data[v]
+                    # print("split", v, a.shape)
+                    arr = a.split(n3-1)
+                    data += arr
                 else:
                     if not len(data): continue
                     n = random.randint(1,3)

@@ -329,19 +340,49 @@ class TestSetitem(unittest.TestCase):
                     b = np.random.permutation(np.arange(a.numel()))
                     a = a[b][:100]
                     data.append(a)
-            ret = jt.concat(data).numpy()
-            # print(data)
-            return ret
+            ret = jt.concat(data)
+            if backward and len(back):
+                grads = jt.grad(jt.rand_like(ret)*ret, back)
+                return jt.concat(grads).numpy()
+            return ret.numpy()

+        for s in range(100):
+            print("check", s)
+            for check_grad in [True, False]:
+                jt.set_global_seed(s)
+                data = check(check_grad)
+                jt.gc()
+                jt.set_global_seed(s)
+                with jt.flag_scope(gopt_disable=1):
+                    data2 = check(check_grad)
+                jt.gc()
+                np.testing.assert_allclose(data, data2, atol=1e-5, rtol=1e-5)
+
+    def test_concat_grad(self):
+        n = 30000
+        m = 100
+        arr = []
+        for i in range(n):
+            arr.append(jt.random((m,)))
+        x = jt.concat(arr)
+        y = jt.rand_like(x)
+        grads = jt.grad(x*y, arr)
+        for i in range(n):
+            np.testing.assert_allclose(grads[i].numpy(), y[i*m:(i+1)*m].numpy())
+
+    def test_split_grad(self):
+        n = 30000
+        m = 100
+        x = jt.random((n*m,))
+        arr = x.split(m)
+        yy = [ jt.rand(m) for i in range(n) ]
+        arr2 = [ y*yy[i] for i,y in enumerate(arr) ]
+        g = jt.grad(jt.concat(arr2), x)
+        for i in range(n):
+            np.testing.assert_allclose(g.data[i*m:(i+1)*m], yy[i].data)

-        for s in range(1000):
-            jt.set_global_seed(s)
-            data = check()
-            jt.gc()
-            jt.set_global_seed(s)
-            with jt.flag_scope(gopt_disable=1):
-                data2 = check()
-            jt.gc()
-            np.testing.assert_allclose(data, data2)

 if __name__ == "__main__":

@@ -560,6 +560,15 @@ try:
 except:
     pass

+try:
+    import sys
+    sys.setrecursionlimit(10**6)
+    if os.name != 'nt':
+        import resource
+        resource.setrlimit(resource.RLIMIT_STACK, (2**29,-1))
+except:
+    pass
+
 if os.name == 'nt':
     if check_msvc_install:
         if not os.path.isfile(cc_path):