optimize concat and split

Dun Liang 2022-04-04 22:53:20 +08:00
parent 54fb38caed
commit 9c74699707
29 changed files with 226 additions and 59 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.2.1'
__version__ = '1.3.2.2'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -231,6 +231,7 @@ def _merge_dtypes(dtypes):
dtype = names[s]+("" if e ==-1 else dbytes[e])
return dtype
@jt.flag_scope(amp_reg=4) # _custom_flag
def concat(arr, dim=0):
'''Concat Operator can concat a list of jt Var along a specific dimension.
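
The `@jt.flag_scope(amp_reg=4)` decorator (annotated `# _custom_flag`) appears to reuse amp_reg bit 2 so that the ops created inside `concat` carry the `_custom_flag` bit introduced in node.h further below; the setitem gradient and inplace passes then key off that bit. The call site itself is unchanged. A minimal usage sketch (shapes are illustrative):

```
import jittor as jt

a = jt.ones((2, 3))
b = jt.zeros((3, 3))
c = jt.concat([a, b], dim=0)  # internal setitem ops are tagged via the flag scope
print(c.shape)                # concatenated along dim 0 -> (5, 3)
```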

View File

@ -34,7 +34,7 @@ class Generator(nn.Module):
nn.Tanh())
def execute(self, noise, labels):
gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)
gen_input = jt.concat((self.label_emb(labels), noise), dim=1)
img = self.model(gen_input)
img = img.view((img.shape[0], *img_shape))
return img
@ -55,7 +55,7 @@ class Discriminator(nn.Module):
nn.Linear(512, 1))
def execute(self, img, labels):
d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
d_in = jt.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
validity = self.model(d_in)
return validity
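
This and the following model/test hunks are a mechanical rename from `jt.contrib.concat` to `jt.concat`; as far as I can tell both names resolve to the same operator, so behaviour at the call sites is unchanged. A hedged equivalence check:

```
import jittor as jt
import numpy as np

a, b = jt.ones((2, 2)), jt.zeros((2, 2))
np.testing.assert_array_equal(jt.concat([a, b], dim=1).numpy(),
                              jt.contrib.concat([a, b], dim=1).numpy())
```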

View File

@ -121,8 +121,12 @@ void CublasBatchedMatmulOp::jit_run() {
#if CUDART_VERSION >= 11000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
if (use_tensorcore) {
if (use_tensorcore>=3) {
computeType = CUBLAS_COMPUTE_32F_FAST_16F;
} else if (use_tensorcore==2) {
computeType = CUBLAS_COMPUTE_32F_FAST_16BF;
} else if (use_tensorcore==1) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
if (a->dtype() == ns_float16
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {

View File

@ -78,8 +78,12 @@ void CublasMatmulOp::jit_run() {
#if CUDART_VERSION >= 11000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
if (use_tensorcore) {
if (use_tensorcore>=3) {
computeType = CUBLAS_COMPUTE_32F_FAST_16F;
} else if (use_tensorcore==2) {
computeType = CUBLAS_COMPUTE_32F_FAST_16BF;
} else if (use_tensorcore==1) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
if (a->dtype() == ns_float16
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
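
Both the batched and plain matmul kernels now map `use_tensorcore` to a cuBLAS compute type: 1 selects TF32, 2 selects FP32 accumulation with BF16 inputs, and 3 or higher selects FP32 accumulation with FP16 inputs. A hedged sketch of driving this from Python, assuming the flag is exposed as `jt.flags.use_tensorcore` like other jittor flags:

```
import jittor as jt

jt.flags.use_cuda = 1
jt.flags.use_tensorcore = 1  # assumed mapping: 0 off, 1 TF32, 2 FAST_16BF, 3+ FAST_16F
a = jt.rand(256, 256)
b = jt.rand(256, 256)
jt.matmul(a, b).sync()
```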

View File

@ -257,7 +257,7 @@ def stack(x, dim=0):
return x[0].unsqueeze(dim)
res = [x_.unsqueeze(dim) for x_ in x]
return jt.contrib.concat(res, dim=dim)
return jt.concat(res, dim=dim)
jt.Var.stack = stack
def flip(x, dim=0):
@ -342,7 +342,7 @@ def cross(input, other, dim=-1):
a1 = input[(slice(None,),)*dim+(1,)]*other[(slice(None,),)*dim+(2,)]-input[(slice(None,),)*dim+(2,)]*other[(slice(None,),)*dim+(1,)]
a2 = input[(slice(None,),)*dim+(2,)]*other[(slice(None,),)*dim+(0,)]-input[(slice(None,),)*dim+(0,)]*other[(slice(None,),)*dim+(2,)]
a3 = input[(slice(None,),)*dim+(0,)]*other[(slice(None,),)*dim+(1,)]-input[(slice(None,),)*dim+(1,)]*other[(slice(None,),)*dim+(0,)]
return jt.contrib.concat([a1.unsqueeze(dim),a2.unsqueeze(dim),a3.unsqueeze(dim)], dim=dim)
return jt.concat([a1.unsqueeze(dim),a2.unsqueeze(dim),a3.unsqueeze(dim)], dim=dim)
jt.Var.cross = cross
def normalize(input, p=2, dim=1, eps=1e-30):
@ -465,7 +465,7 @@ def unique(x):
_,x = jt.argsort(x)
index,= jt.index((x.shape[0],))
y = x[1:][x[index[1:]] != x[index[:-1]]]
x = jt.contrib.concat([x[:1],y],dim=0)
x = jt.concat([x[:1],y],dim=0)
return x
jt.Var.unique = unique
@ -501,7 +501,7 @@ def nonzero(x):
x = [xx.unsqueeze(1) for xx in x]
if len(x)<2:
return x[0]
x = jt.contrib.concat(x,dim=1)
x = jt.concat(x,dim=1)
return x
jt.Var.nonzero = nonzero
@ -546,7 +546,7 @@ def meshgrid(*tensors):
return grids
def split(d,split_size,dim):
def split(d, split_size, dim=0):
r'''
Splits the tensor into chunks. Each chunk is a view of the original tensor.
@ -575,7 +575,9 @@ def split(d,split_size,dim):
ans = []
last = 0
for i in split_size:
s_last = len(split_size)-1
gopt_disable = jt.flags.gopt_disable
for j, i in enumerate(split_size):
if i==0:
shape = list(d.shape)
shape[dim]=0
@ -584,7 +586,11 @@ def split(d,split_size,dim):
continue
ss = (slice(None),)*dim+(slice(last,last+i),)
new_d = d[ss]
if gopt_disable:
new_d = d.getitem(ss)
else:
new_d, d = d.getitem(ss, int(j==s_last))
last +=i
ans.append(new_d)
return tuple(ans)
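
With graph optimization enabled, each chunk is now produced by the two-output getitem; the second output shares storage with the input, and `int(j==s_last)` marks the final chunk, which the optimized backward below relies on. Under `jt.flags.gopt_disable` the old single-output getitem path is kept. A small usage sketch of the updated signature (`dim` now defaults to 0):

```
import jittor as jt

x = jt.arange(10)
chunks = x.split(3)            # chunk sizes 3, 3, 3, 1 along dim 0
print([c.shape for c in chunks])
a, b = x.split([4, 6], dim=0)  # explicit chunk sizes (list form)
```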

View File

@ -76,7 +76,7 @@ class _DenseLayer(nn.Sequential):
new_features = super(_DenseLayer, self).execute(x)
if (self.drop_rate > 0):
new_features = self.drop(new_features)
return jt.contrib.concat([x, new_features], dim=1)
return jt.concat([x, new_features], dim=1)
class _DenseBlock(nn.Sequential):

View File

@ -121,7 +121,7 @@ class Inception(nn.Module):
def execute(self, x):
outputs = self._forward(x)
return jt.contrib.concat(outputs, dim=1)
return jt.concat(outputs, dim=1)
class InceptionAux(nn.Module):

View File

@ -118,7 +118,7 @@ class InceptionA(nn.Module):
def execute(self, x):
outputs = self._forward(x)
return jt.contrib.concat(outputs, dim=1)
return jt.concat(outputs, dim=1)
class InceptionB(nn.Module):
@ -142,7 +142,7 @@ class InceptionB(nn.Module):
def execute(self, x):
outputs = self._forward(x)
return jt.contrib.concat(outputs, dim=1)
return jt.concat(outputs, dim=1)
class InceptionC(nn.Module):
@ -179,7 +179,7 @@ class InceptionC(nn.Module):
def execute(self, x):
outputs = self._forward(x)
return jt.contrib.concat(outputs, dim=1)
return jt.concat(outputs, dim=1)
class InceptionD(nn.Module):
@ -207,7 +207,7 @@ class InceptionD(nn.Module):
def execute(self, x):
outputs = self._forward(x)
return jt.contrib.concat(outputs, dim=1)
return jt.concat(outputs, dim=1)
class InceptionE(nn.Module):
@ -229,11 +229,11 @@ class InceptionE(nn.Module):
branch1x1 = self.branch1x1(x)
branch3x3 = self.branch3x3_1(x)
branch3x3 = [self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3)]
branch3x3 = jt.contrib.concat(branch3x3, dim=1)
branch3x3 = jt.concat(branch3x3, dim=1)
branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = [self.branch3x3dbl_3a(branch3x3dbl), self.branch3x3dbl_3b(branch3x3dbl)]
branch3x3dbl = jt.contrib.concat(branch3x3dbl, dim=1)
branch3x3dbl = jt.concat(branch3x3dbl, dim=1)
branch_pool = nn.pool(x, kernel_size=3, op="mean", stride=1, padding=1)
branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
@ -241,7 +241,7 @@ class InceptionE(nn.Module):
def execute(self, x):
outputs = self._forward(x)
return jt.contrib.concat(outputs, dim=1)
return jt.concat(outputs, dim=1)
class InceptionAux(nn.Module):

View File

@ -45,9 +45,9 @@ class InvertedResidual(nn.Module):
if (self.stride == 1):
x1 = x[:,0:x.shape[1]//2]
x2 = x[:,x.shape[1]//2:x.shape[1]]
out = jt.contrib.concat([x1, self.branch2(x2)], dim=1)
out = jt.concat([x1, self.branch2(x2)], dim=1)
else:
out = jt.contrib.concat([self.branch1(x), self.branch2(x)], dim=1)
out = jt.concat([self.branch1(x), self.branch2(x)], dim=1)
out = channel_shuffle(out, 2)
return out

View File

@ -26,7 +26,7 @@ class Fire(nn.Module):
def execute(self, x):
x = self.squeeze_activation(self.squeeze(x))
return jt.contrib.concat([self.expand1x1_activation(self.expand1x1(x)), self.expand3x3_activation(self.expand3x3(x))], dim=1)
return jt.concat([self.expand1x1_activation(self.expand1x1(x)), self.expand3x3_activation(self.expand3x3(x))], dim=1)
class SqueezeNet(nn.Module):

View File

@ -80,7 +80,7 @@ class Generator(nn.Module):
nn.Tanh())
def execute(self, noise, labels):
gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)
gen_input = jt.concat((self.label_emb(labels), noise), dim=1)
img = self.model(gen_input)
img = img.view((img.shape[0], *img_shape))
return img
@ -105,7 +105,7 @@ class Discriminator(nn.Module):
nn.Linear(512, 1))
def execute(self, img, labels):
d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
d_in = jt.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
validity = self.model(d_in)
return validity

View File

@ -31,6 +31,8 @@ pytype_map = {
"int": ["PyLong_AsLong", "PyLong_FromLong", "PyLong_CheckExact"],
"int64": ["PyLong_AsLongLong", "PyLong_FromLongLong", "PyLong_CheckExact"],
"uint": ["PyLong_AsUnsignedLong", "PyLong_FromUnsignedLong", "PyLong_CheckExact"],
"uint8": ["PyLong_AsUnsignedLong", "PyLong_FromUnsignedLong", "PyLong_CheckExact"],
"uint16": ["PyLong_AsUnsignedLong", "PyLong_FromUnsignedLong", "PyLong_CheckExact"],
"uint64": ["PyLong_AsUnsignedLongLong", "PyLong_FromUnsignedLongLong", "PyLong_CheckExact"],
"void": ["...", "GET_PY_NONE", "..."],
"PyObject*": ["","",""],

View File

@ -29,7 +29,7 @@ class SparseVar:
def t(self):
indices = list(self.indices.split(1,dim=0))
indices[-1],indices[-2] = indices[-2],indices[-1]
indices = jt.contrib.concat(indices,dim=0)
indices = jt.concat(indices,dim=0)
shape = list(self.shape)
shape[-1],shape[-2] = shape[-2],shape[-1]
shape = jt.NanoVector(shape)

View File

@ -132,4 +132,46 @@ void clean_graph() {
}
}
void check_circle(Node* s) {
vector<Node*> q = {s};
vector<int> fa = {-1};
unordered_set<Node*> visited = {s};
for (int i=0; i<q.size(); i++) {
auto n = q[i];
for (auto o : n->outputs()) {
if (o == s) {
LOGe << "Found circle:";
int j=i;
vector<Node*> nodes{o};
while (j) {
nodes.push_back(q[j]);
j = fa[j];
}
for (int i=0; i<nodes.size(); i++) {
auto n = nodes[i];
auto out = nodes[(i-1+nodes.size())%nodes.size()];
auto in = nodes[(i+1)%nodes.size()];
int in_id=0, out_id=0;
for (auto ii : n->inputs()) {
if (ii == in) break;
in_id ++;
}
for (auto oo : n->outputs()) {
if (oo == out) break;
out_id ++;
}
LOGe << n << "in:" >> in_id >> '/' >> n->inputs().size() << "out:" >> out_id >> '/' >> n->outputs().size();
}
LOGf << "found circle";
}
if (!visited.count(o)) {
visited.emplace(o);
q.push_back(o);
fa.push_back(i);
}
}
}
}
} // jittor

View File

@ -148,4 +148,6 @@ void toplogical_sort_backward(vector<Node*>& nodes, vector<Node*>& sorted, Func&
ASSERTop(nodes.size(),==,sorted.size());
}
void check_circle(Node* s);
} // jittor

View File

@ -9,17 +9,6 @@
namespace jittor {
DEFINE_FLAG(int, amp_reg, 0, "Auto mixed-precision control registers, bit 0: prefer 32; bit 1: prefer 16; bit 2: keep reduce type; bit 3: keep white list type; bit 4: array like op prefer too");
DEFINE_FLAG_WITH_SETTER(int, auto_mixed_precision_level, 0, "Auto mixed-precision optimization level, 0: do not use fp16, 1-3: preserve level, do not use fp16 for now; 4: prefer fp16, but some ops use fp32 e.g. sum, exp; 5: similar to 4, and array op will automatically convert to fp16; 6: all ops prefer fp16");
void setter_auto_mixed_precision_level(int value) {
if (value <= 3) amp_reg = 0; else
if (value == 4) amp_reg = amp_prefer16; else
if (value == 5) amp_reg = amp_prefer16 | amp_array_prefer; else
if (value == 6) amp_reg = amp_prefer16 | amp_array_prefer | amp_keep_reduce | amp_keep_white;
}
#define FOR_ALL_TYPES(m) \
m(bool) \
m(int8) \

View File

@ -59,6 +59,7 @@ struct NodeFlags {
_prefer_16=_n+10,
// bit11: reduce keep type unchanged
_reduce_keep=_n+11,
_custom_flag=_reduce_keep,
};
inline void set(Flags f, int a=1, int nbits=1) {

View File

@ -6,6 +6,7 @@
// ***************************************************************
#include <cmath>
#include "var.h"
#include "executor.h"
#include "ops/getitem_op.h"
#include "ops/op_register.h"
#ifdef JIT_cuda
@ -27,6 +28,8 @@ namespace jittor {
static auto make_number = get_op_info("number")
.get_constructor<VarPtr, float, Var*>();
static auto make_empty = get_op_info("empty")
.get_constructor<VarPtr, NanoVector, NanoString>();
static auto make_setitem = get_op_info("setitem")
.get_constructor<VarPtr, Var*, VarSlices&&, Var*, NanoString>();
@ -38,6 +41,19 @@ GetitemOp::GetitemOp(Var* x, VarSlices&& slices)
create_output(nullptr, x->dtype());
}
GetitemOp::GetitemOp(Var* x, VarSlices&& slices, int _)
: vs(move(slices)) {
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
flags.set(NodeFlags::_has_gopt);
flags.set(NodeFlags::_custom_flag);
flags.set(NodeFlags::_grads);
create_output(nullptr, x->dtype());
auto out2 = create_output(nullptr, x->dtype());
out2->share_with(x);
ns.data = _;
}
void GetitemOp::infer_slices(
StackVector<>& __restrict__ i_to_vs,
StackVector<>& __restrict__ i_to_o,
@ -377,6 +393,10 @@ void GetitemOp::infer_shape() {
this->i_to_vs = i_to_vs.to_nano_vector();
this->i_to_o = i_to_o.to_nano_vector();
this->o_shape = o_shape.to_nano_vector();
if (outputs().size() > 1) {
auto out2 = output(1);
out2->set_shape(in->shape);
}
LOGV(999) << "\ni_to_vs:" << i_to_vs
<< "\ni_to_o:" << i_to_o
@ -396,6 +416,24 @@ VarPtr GetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return make_setitem(zeros, VarSlices(vs, true), dout, ns_void);
}
void GetitemOp::grads(Var** dout, VarPtr* dins) {
VarPtr x = dout[1];
VarPtr y = dout[0];
if (!x) {
auto in = inputs().front();
if (in->num<0) exe.run_sync(vector<Var*>({in}), true);
// ns.data marks whether this is the last split var
if (ns.data)
x = make_empty(in->shape, in->dtype());
else
x = make_number(0, in);
}
if (!y) {
y = make_number(0, outputs().front());
}
dins[0] = make_setitem(x, VarSlices(vs, true), y, ns_void);
}
void GetitemOp::jit_prepare(JK& jk) {
auto in = inputs().front();
int idim = i_to_vs.size();
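
The new `grads()` pairs the gradient of the sliced output (`dout[0]`) with the gradient of the shared-storage output (`dout[1]`). When the latter is absent it is seeded with an uninitialized empty var for the last chunk of a split (`ns.data` set) or with zeros otherwise, and the incoming gradient is then written into that buffer with setitem, so a split's chunk gradients accumulate into one tensor instead of producing one dense zero-padded gradient per chunk. A hedged end-to-end sketch of the behaviour this enables:

```
import jittor as jt
import numpy as np

x = jt.rand(6)
a, b = x.split(3)
g = jt.grad(a.sum() * 1.0 + b.sum() * 2.0, x)
np.testing.assert_allclose(g.numpy(), np.concatenate([np.ones(3), 2 * np.ones(3)]))
```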

View File

@ -22,9 +22,12 @@ struct GetitemOp : Op {
int first_oid_of_var, var_dim;
GetitemOp(Var* x, VarSlices&& slices);
// @attrs(multiple_outputs)
GetitemOp(Var* x, VarSlices&& slices, int _);
const char* name() const override { return "getitem"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void grads(Var** dout, VarPtr* dins) override;
void infer_shape() override;
void compile_optimize(string& src) override;
void graph_optimize() override;

View File

@ -28,6 +28,8 @@ static auto make_array = get_op_info("array")
.get_constructor<VarPtr, const void*, NanoVector, NanoString>();
static auto make_getitem = get_op_info("getitem")
.get_constructor<VarPtr, Var*, VarSlices&&>();
static auto make_getitem2 = get_op_info("getitem")
.get_constructor<vector<VarPtr>, Var*, VarSlices&&, int>();
static auto make_setitem = get_op_info("setitem")
.get_constructor<VarPtr, Var*, VarSlices&&, Var*, NanoString>();
static auto make_binary = get_op_info("binary")
@ -40,8 +42,11 @@ SetitemOp::SetitemOp(Var* x, VarSlices&& slices, Var* y, NanoString op)
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
flags.set(NodeFlags::_has_gopt);
ASSERT(ns == ns_void || ns.is_binary());
ASSERT(op == ns_void || op.is_binary());
create_output(nullptr, x->dtype());
if (flags.get(NodeFlags::_custom_flag)) {
flags.set(NodeFlags::_grads);
}
}
void SetitemOp::infer_shape() {
@ -120,6 +125,12 @@ void SetitemOp::infer_shape() {
<< "\no_shape:" << o_shape;
}
void SetitemOp::grads(Var** dout, VarPtr* dins) {
auto outs = make_getitem2(dout[0], VarSlices(vs, true), 0);
dins[0] = move(outs[1]);
dins[1] = move(outs[0]);
}
VarPtr SetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
if (v_index >= 2)
return nullptr;
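
`SetitemOp::grads` is the mirror image: the multi-output getitem (`make_getitem2`) yields both input gradients from a single op, so the backward of a long concat chain becomes a sequence of slices rather than a full-size setitem into zeros for every input; `test_concat_grad` below checks exactly this. A short hedged sketch:

```
import jittor as jt
import numpy as np

a, b = jt.rand(3), jt.rand(3)
c = jt.concat([a, b])          # concat is assembled from setitem ops
w = jt.rand_like(c)
ga, gb = jt.grad((c * w).sum(), [a, b])
np.testing.assert_allclose(ga.numpy(), w.numpy()[:3], rtol=1e-6)
np.testing.assert_allclose(gb.numpy(), w.numpy()[3:], rtol=1e-6)
```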

View File

@ -25,6 +25,7 @@ struct SetitemOp : Op {
const char* name() const override { return "setitem"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void grads(Var** dout, VarPtr* dins) override;
void infer_shape() override;
void compile_optimize(string& src) override;
void graph_optimize() override;

View File

@ -73,6 +73,9 @@ static void setitem_inplace(SetitemOp* op) {
return;
if (data->allocator)
return;
auto data_op = data->input();
if (data_op->flags.get(NodeFlags::_custom_flag))
return;
auto in_shape = input->shape;
int64 inplace_size = 1;

View File

@ -22,6 +22,16 @@ DEFINE_FLAG(bool, no_grad, 0,
"No grad for all jittor Var creation");
DEFINE_FLAG(bool, no_fuse, 0,
"No fusion optimization for all jittor Var creation");
DEFINE_FLAG(int, amp_reg, 0, "Auto mixed-precision control registers, bit 0: prefer 32; bit 1: prefer 16; bit 2: keep reduce type; bit 3: keep white list type; bit 4: array like op prefer too");
DEFINE_FLAG_WITH_SETTER(int, auto_mixed_precision_level, 0, "Auto mixed-precision optimization level, 0: do not use fp16, 1-3: preserve level, do not use fp16 for now; 4: prefer fp16, but some ops use fp32 e.g. sum, exp; 5: similar to 4, and array op will automatically convert to fp16; 6: all ops prefer fp16");
void setter_auto_mixed_precision_level(int value) {
if (value <= 3) amp_reg = 0; else
if (value == 4) amp_reg = amp_prefer16; else
if (value == 5) amp_reg = amp_prefer16 | amp_array_prefer; else
if (value == 6) amp_reg = amp_prefer16 | amp_array_prefer | amp_keep_reduce | amp_keep_white;
}
Var::Var(NanoVector shape, NanoString dtype)
: shape(shape),

View File

@ -48,7 +48,7 @@ class TestConcatOp(unittest.TestCase):
def test_concat_op(self):
def check(tmp, dim=0):
res1 = numpy_concat(tmp, dim=dim)
res2 = jt.contrib.concat(tmp, dim=dim)
res2 = jt.concat(tmp, dim=dim)
assert (res2!=res1).data.sum()==0, "concat fail..."
check([jt.array([[1],[2]]), jt.array([[2],[2]])])
check([jt.array(np.array(range(24))).reshape((1,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])
@ -68,7 +68,7 @@ class TestConcatOp(unittest.TestCase):
arr = []
for i in range(m):
arr.append(a[(slice(None),)*dim + (slice(i*size,i*size+size),)])
b = jt.contrib.concat(arr, dim)
b = jt.concat(arr, dim)
if backward:
loss = b * a
b = jt.grad(loss, a)

View File

@ -20,7 +20,7 @@ class TestContrib(unittest.TestCase):
arr1.append(a)
arr2.append(jt.array(a))
x = np.concatenate(tuple(arr1), dim)
y = jt.contrib.concat(arr2, dim)
y = jt.concat(arr2, dim)
assert (x==y.data).all(), (x, y.data, arr1, arr2)
check([2,3,4], 0, 2)
check([2,3,4], 1, 3)

View File

@ -104,7 +104,7 @@ class TestSingleArray(unittest.TestCase):
arr1.append(a)
arr2.append(jt.array(a))
x = np.concatenate(tuple(arr1), dim)
y = jt.contrib.concat(arr2, dim)
y = jt.concat(arr2, dim)
assert (x==y.data).all()
check([1], 0, 20)

View File

@ -39,7 +39,7 @@ class TestSetitem(unittest.TestCase):
arr21 = jt.ones((2,2))
arr22 = jt.ones((2,2)) * 2
arr2 = jt.contrib.concat([arr21, arr22], dim=0)
arr2 = jt.concat([arr21, arr22], dim=0)
arr2.sync()
arr21.data[0,0] = 3
arr22.data[0,0] = 4
@ -283,14 +283,17 @@ class TestSetitem(unittest.TestCase):
np.concatenate([b.data,c.data]))
def test_concat_random(self):
def check():
def check(backward=False):
n1, n2, n3 = 1000, 20, 10
# n1, n2, n3 = 2, 2, 1
# n1, n2, n3 = 3, 2, 3
import random
data = []
back = []
for i in range(n1):
if len(data) > n2:
del data[random.randint(0,len(data)-1)]
v = random.randint(0,len(data)-1)
# print("del", v)
del data[v]
x1 = random.randint(0,9)
# print(i, x1)
if len(data) == 0:
@ -319,7 +322,15 @@ class TestSetitem(unittest.TestCase):
elif x1 == 4:
# a = jt.random((random.randint(10,20),))
a = jt.array(np.random.rand(random.randint(n3,n3*2)))
if backward and random.randint(0,1):
back.append(a)
data.append(a)
elif x1 == 5:
v = random.randint(0,len(data)-1)
a = data[v]
# print("split", v, a.shape)
arr = a.split(n3-1)
data += arr
else:
if not len(data): continue
n = random.randint(1,3)
@ -329,19 +340,49 @@ class TestSetitem(unittest.TestCase):
b = np.random.permutation(np.arange(a.numel()))
a = a[b][:100]
data.append(a)
ret = jt.concat(data).numpy()
# print(data)
return ret
ret = jt.concat(data)
if backward and len(back):
grads = jt.grad(jt.rand_like(ret)*ret, back)
return jt.concat(grads).numpy()
return ret.numpy()
for s in range(100):
print("check", s)
for check_grad in [True, False]:
jt.set_global_seed(s)
data = check(check_grad)
jt.gc()
jt.set_global_seed(s)
with jt.flag_scope(gopt_disable=1):
data2 = check(check_grad)
jt.gc()
np.testing.assert_allclose(data, data2, atol=1e-5, rtol=1e-5)
def test_concat_grad(self):
n = 30000
m = 100
arr = []
for i in range(n):
arr.append(jt.random((m,)))
x = jt.concat(arr)
y = jt.rand_like(x)
grads = jt.grad(x*y, arr)
for i in range(n):
np.testing.assert_allclose(grads[i].numpy(), y[i*m:(i+1)*m].numpy())
def test_split_grad(self):
n = 30000
m = 100
x = jt.random((n*m,))
arr = x.split(m)
yy = [ jt.rand(m) for i in range(n) ]
arr2 = [ y*yy[i] for i,y in enumerate(arr) ]
g = jt.grad(jt.concat(arr2), x)
for i in range(n):
np.testing.assert_allclose(g.data[i*m:(i+1)*m], yy[i].data)
for s in range(1000):
jt.set_global_seed(s)
data = check()
jt.gc()
jt.set_global_seed(s)
with jt.flag_scope(gopt_disable=1):
data2 = check()
jt.gc()
np.testing.assert_allclose(data, data2)
if __name__ == "__main__":

View File

@ -560,6 +560,15 @@ try:
except:
pass
try:
import sys
sys.setrecursionlimit(10**6)
if os.name != 'nt':
import resource
resource.setrlimit(resource.RLIMIT_STACK, (2**29,-1))
except:
pass
if os.name == 'nt':
if check_msvc_install:
if not os.path.isfile(cc_path):