mirror of https://github.com/Jittor/Jittor

optimize concat and split

This commit is contained in:
parent 54fb38caed
commit 9c74699707
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
 
-__version__ = '1.3.2.1'
+__version__ = '1.3.2.2'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
@@ -231,6 +231,7 @@ def _merge_dtypes(dtypes):
     dtype = names[s]+("" if e ==-1 else dbytes[e])
     return dtype
 
+@jt.flag_scope(amp_reg=4) # _custom_flag
 def concat(arr, dim=0):
     '''Concat Operator can concat a list of jt Var at a specfic dimension.
 
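The decorator added above runs concat under an amp_reg scope: value 4 corresponds to the "keep reduce type" bit in the amp_reg flag description added further down in this commit, and node.h (also below) aliases that bit's node flag (_reduce_keep) to _custom_flag, which the getitem/setitem changes in this commit check, per the `# _custom_flag` comment. A minimal sketch of the equivalent context-manager form (array contents are illustrative):

```python
import jittor as jt

a = jt.random((2, 3))
b = jt.random((2, 3))

# Same effect as the @jt.flag_scope(amp_reg=4) decorator: every op created
# while the scope is active sees amp_reg=4.
with jt.flag_scope(amp_reg=4):
    c = jt.concat([a, b], dim=0)
print(c.shape)  # concatenation along dim 0 -> [4, 3]
```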
@@ -34,7 +34,7 @@ class Generator(nn.Module):
             nn.Tanh())
 
     def execute(self, noise, labels):
-        gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)
+        gen_input = jt.concat((self.label_emb(labels), noise), dim=1)
         img = self.model(gen_input)
         img = img.view((img.shape[0], *img_shape))
         return img

@@ -55,7 +55,7 @@ class Discriminator(nn.Module):
             nn.Linear(512, 1))
 
     def execute(self, img, labels):
-        d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
+        d_in = jt.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
         validity = self.model(d_in)
         return validity
 
@@ -121,8 +121,12 @@ void CublasBatchedMatmulOp::jit_run() {
     #if CUDART_VERSION >= 11000
     cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
     cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
-    if (use_tensorcore) {
+    if (use_tensorcore>=3) {
         computeType = CUBLAS_COMPUTE_32F_FAST_16F;
+    } else if (use_tensorcore==2) {
+        computeType = CUBLAS_COMPUTE_32F_FAST_16BF;
+    } else if (use_tensorcore==1) {
+        computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
     }
     if (a->dtype() == ns_float16
         || b->dtype() == ns_float16 || c->dtype() == ns_float16) {

@@ -78,8 +78,12 @@ void CublasMatmulOp::jit_run() {
     #if CUDART_VERSION >= 11000
     cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
     cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
-    if (use_tensorcore) {
+    if (use_tensorcore>=3) {
         computeType = CUBLAS_COMPUTE_32F_FAST_16F;
+    } else if (use_tensorcore==2) {
+        computeType = CUBLAS_COMPUTE_32F_FAST_16BF;
+    } else if (use_tensorcore==1) {
+        computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
     }
     if (a->dtype() == ns_float16
         || b->dtype() == ns_float16 || c->dtype() == ns_float16) {
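The two hunks above turn the boolean tensor-core switch into a level: inputs and outputs stay float32, but cuBLAS accumulates through progressively cheaper tensor-core paths. A hedged Python-side sketch, assuming the flag is exposed as jt.flags.use_tensorcore like other Jittor flags and that a CUDA build and device are available:

```python
import jittor as jt

jt.flags.use_cuda = 1            # requires a CUDA build and a GPU
x = jt.random((1024, 1024))
w = jt.random((1024, 1024))

# Mapping from the hunks above:
#   1  -> CUBLAS_COMPUTE_32F_FAST_TF32
#   2  -> CUBLAS_COMPUTE_32F_FAST_16BF
#   3+ -> CUBLAS_COMPUTE_32F_FAST_16F
jt.flags.use_tensorcore = 1
y = jt.matmul(x, w)
y.sync()
```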
@@ -257,7 +257,7 @@ def stack(x, dim=0):
         return x[0].unsqueeze(dim)
 
     res = [x_.unsqueeze(dim) for x_ in x]
-    return jt.contrib.concat(res, dim=dim)
+    return jt.concat(res, dim=dim)
 jt.Var.stack = stack
 
 def flip(x, dim=0):

@@ -342,7 +342,7 @@ def cross(input, other, dim=-1):
     a1 = input[(slice(None,),)*dim+(1,)]*other[(slice(None,),)*dim+(2,)]-input[(slice(None,),)*dim+(2,)]*other[(slice(None,),)*dim+(1,)]
     a2 = input[(slice(None,),)*dim+(2,)]*other[(slice(None,),)*dim+(0,)]-input[(slice(None,),)*dim+(0,)]*other[(slice(None,),)*dim+(2,)]
     a3 = input[(slice(None,),)*dim+(0,)]*other[(slice(None,),)*dim+(1,)]-input[(slice(None,),)*dim+(1,)]*other[(slice(None,),)*dim+(0,)]
-    return jt.contrib.concat([a1.unsqueeze(dim),a2.unsqueeze(dim),a3.unsqueeze(dim)], dim=dim)
+    return jt.concat([a1.unsqueeze(dim),a2.unsqueeze(dim),a3.unsqueeze(dim)], dim=dim)
 jt.Var.cross = cross
 
 def normalize(input, p=2, dim=1, eps=1e-30):

@@ -465,7 +465,7 @@ def unique(x):
     _,x = jt.argsort(x)
     index,= jt.index((x.shape[0],))
     y = x[1:][x[index[1:]] != x[index[:-1]]]
-    x = jt.contrib.concat([x[:1],y],dim=0)
+    x = jt.concat([x[:1],y],dim=0)
     return x
 
 jt.Var.unique = unique

@@ -501,7 +501,7 @@ def nonzero(x):
     x = [xx.unsqueeze(1) for xx in x]
     if len(x)<2:
         return x[0]
-    x = jt.contrib.concat(x,dim=1)
+    x = jt.concat(x,dim=1)
     return x
 
 jt.Var.nonzero = nonzero

@@ -546,7 +546,7 @@ def meshgrid(*tensors):
     return grids
 
 
-def split(d,split_size,dim):
+def split(d, split_size, dim=0):
     r'''
     Splits the tensor into chunks. Each chunk is a view of the original tensor.
 

@@ -575,7 +575,9 @@ def split(d,split_size,dim):
 
     ans = []
     last = 0
-    for i in split_size:
+    s_last = len(split_size)-1
+    gopt_disable = jt.flags.gopt_disable
+    for j, i in enumerate(split_size):
         if i==0:
             shape = list(d.shape)
             shape[dim]=0

@@ -584,7 +586,11 @@ def split(d,split_size,dim):
             continue
 
         ss = (slice(None),)*dim+(slice(last,last+i),)
-        new_d = d[ss]
+        if gopt_disable:
+            new_d = d.getitem(ss)
+        else:
+            new_d, d = d.getitem(ss, int(j==s_last))
+
         last +=i
         ans.append(new_d)
     return tuple(ans)
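With this change each chunk comes from a getitem op that, unless gopt_disable is set, also returns a pass-through output sharing the input's storage; the int(j==s_last) argument tags the final chunk (stored as ns.data on the C++ op) so its backward pass can start from an empty buffer instead of a zero-filled one. A small usage sketch (shapes are illustrative; the trailing chunk is expected to carry the remainder):

```python
import jittor as jt

x = jt.random((10, 4))
parts = x.split(3, dim=0)   # chunk size 3 -> expected shapes [3,4], [3,4], [3,4], [1,4]
print([p.shape for p in parts])

# Force the fallback that slices each chunk with a plain getitem
with jt.flag_scope(gopt_disable=1):
    parts2 = x.split(3, dim=0)
```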
@@ -76,7 +76,7 @@ class _DenseLayer(nn.Sequential):
         new_features = super(_DenseLayer, self).execute(x)
         if (self.drop_rate > 0):
             new_features = self.drop(new_features)
-        return jt.contrib.concat([x, new_features], dim=1)
+        return jt.concat([x, new_features], dim=1)
 
 class _DenseBlock(nn.Sequential):
 

@@ -121,7 +121,7 @@ class Inception(nn.Module):
 
     def execute(self, x):
         outputs = self._forward(x)
-        return jt.contrib.concat(outputs, dim=1)
+        return jt.concat(outputs, dim=1)
 
 class InceptionAux(nn.Module):
 

@@ -118,7 +118,7 @@ class InceptionA(nn.Module):
 
     def execute(self, x):
         outputs = self._forward(x)
-        return jt.contrib.concat(outputs, dim=1)
+        return jt.concat(outputs, dim=1)
 
 class InceptionB(nn.Module):
 

@@ -142,7 +142,7 @@ class InceptionB(nn.Module):
 
     def execute(self, x):
         outputs = self._forward(x)
-        return jt.contrib.concat(outputs, dim=1)
+        return jt.concat(outputs, dim=1)
 
 class InceptionC(nn.Module):
 

@@ -179,7 +179,7 @@ class InceptionC(nn.Module):
 
     def execute(self, x):
         outputs = self._forward(x)
-        return jt.contrib.concat(outputs, dim=1)
+        return jt.concat(outputs, dim=1)
 
 class InceptionD(nn.Module):
 

@@ -207,7 +207,7 @@ class InceptionD(nn.Module):
 
     def execute(self, x):
         outputs = self._forward(x)
-        return jt.contrib.concat(outputs, dim=1)
+        return jt.concat(outputs, dim=1)
 
 class InceptionE(nn.Module):
 

@@ -229,11 +229,11 @@ class InceptionE(nn.Module):
         branch1x1 = self.branch1x1(x)
         branch3x3 = self.branch3x3_1(x)
         branch3x3 = [self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3)]
-        branch3x3 = jt.contrib.concat(branch3x3, dim=1)
+        branch3x3 = jt.concat(branch3x3, dim=1)
         branch3x3dbl = self.branch3x3dbl_1(x)
         branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
         branch3x3dbl = [self.branch3x3dbl_3a(branch3x3dbl), self.branch3x3dbl_3b(branch3x3dbl)]
-        branch3x3dbl = jt.contrib.concat(branch3x3dbl, dim=1)
+        branch3x3dbl = jt.concat(branch3x3dbl, dim=1)
         branch_pool = nn.pool(x, kernel_size=3, op="mean", stride=1, padding=1)
         branch_pool = self.branch_pool(branch_pool)
         outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]

@@ -241,7 +241,7 @@ class InceptionE(nn.Module):
 
     def execute(self, x):
         outputs = self._forward(x)
-        return jt.contrib.concat(outputs, dim=1)
+        return jt.concat(outputs, dim=1)
 
 class InceptionAux(nn.Module):
 

@@ -45,9 +45,9 @@ class InvertedResidual(nn.Module):
         if (self.stride == 1):
             x1 = x[:,0:x.shape[1]//2]
             x2 = x[:,x.shape[1]//2:x.shape[1]]
-            out = jt.contrib.concat([x1, self.branch2(x2)], dim=1)
+            out = jt.concat([x1, self.branch2(x2)], dim=1)
         else:
-            out = jt.contrib.concat([self.branch1(x), self.branch2(x)], dim=1)
+            out = jt.concat([self.branch1(x), self.branch2(x)], dim=1)
         out = channel_shuffle(out, 2)
         return out
 

@@ -26,7 +26,7 @@ class Fire(nn.Module):
 
     def execute(self, x):
         x = self.squeeze_activation(self.squeeze(x))
-        return jt.contrib.concat([self.expand1x1_activation(self.expand1x1(x)), self.expand3x3_activation(self.expand3x3(x))], dim=1)
+        return jt.concat([self.expand1x1_activation(self.expand1x1(x)), self.expand3x3_activation(self.expand3x3(x))], dim=1)
 
 class SqueezeNet(nn.Module):
 
@@ -80,7 +80,7 @@ class Generator(nn.Module):
             nn.Tanh())
 
     def execute(self, noise, labels):
-        gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)
+        gen_input = jt.concat((self.label_emb(labels), noise), dim=1)
         img = self.model(gen_input)
         img = img.view((img.shape[0], *img_shape))
         return img

@@ -105,7 +105,7 @@ class Discriminator(nn.Module):
             nn.Linear(512, 1))
 
     def execute(self, img, labels):
-        d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
+        d_in = jt.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
         validity = self.model(d_in)
         return validity
 ```
@@ -31,6 +31,8 @@ pytype_map = {
     "int": ["PyLong_AsLong", "PyLong_FromLong", "PyLong_CheckExact"],
     "int64": ["PyLong_AsLongLong", "PyLong_FromLongLong", "PyLong_CheckExact"],
     "uint": ["PyLong_AsUnsignedLong", "PyLong_FromUnsignedLong", "PyLong_CheckExact"],
+    "uint8": ["PyLong_AsUnsignedLong", "PyLong_FromUnsignedLong", "PyLong_CheckExact"],
+    "uint16": ["PyLong_AsUnsignedLong", "PyLong_FromUnsignedLong", "PyLong_CheckExact"],
     "uint64": ["PyLong_AsUnsignedLongLong", "PyLong_FromUnsignedLongLong", "PyLong_CheckExact"],
     "void": ["...", "GET_PY_NONE", "..."],
     "PyObject*": ["","",""],
@@ -29,7 +29,7 @@ class SparseVar:
     def t(self):
         indices = list(self.indices.split(1,dim=0))
         indices[-1],indices[-2] = indices[-2],indices[-1]
-        indices = jt.contrib.concat(indices,dim=0)
+        indices = jt.concat(indices,dim=0)
         shape = list(self.shape)
         shape[-1],shape[-2] = shape[-2],shape[-1]
         shape = jt.NanoVector(shape)
@@ -132,4 +132,46 @@ void clean_graph() {
     }
 }
 
+void check_circle(Node* s) {
+    vector<Node*> q = {s};
+    vector<int> fa = {-1};
+    unordered_set<Node*> visited = {s};
+    for (int i=0; i<q.size(); i++) {
+        auto n = q[i];
+        for (auto o : n->outputs()) {
+            if (o == s) {
+                LOGe << "Found circle:";
+                int j=i;
+                vector<Node*> nodes{o};
+                while (j) {
+                    nodes.push_back(q[j]);
+                    j = fa[j];
+                }
+                for (int i=0; i<nodes.size(); i++) {
+                    auto n = nodes[i];
+                    auto out = nodes[(i-1+nodes.size())%nodes.size()];
+                    auto in = nodes[(i+1)%nodes.size()];
+                    int in_id=0, out_id=0;
+                    for (auto ii : n->inputs()) {
+                        if (ii == in) break;
+                        in_id ++;
+                    }
+                    for (auto oo : n->outputs()) {
+                        if (oo == out) break;
+                        out_id ++;
+                    }
+                    LOGe << n << "in:" >> in_id >> '/' >> n->inputs().size() << "out:" >> out_id >> '/' >> n->outputs().size();
+                }
+                LOGf << "found circle";
+            }
+            if (!visited.count(o)) {
+                visited.emplace(o);
+                q.push_back(o);
+                fa.push_back(i);
+            }
+        }
+    }
+
+}
+
 } // jittor

@@ -148,4 +148,6 @@ void toplogical_sort_backward(vector<Node*>& nodes, vector<Node*>& sorted, Func&
     ASSERTop(nodes.size(),==,sorted.size());
 }
 
+void check_circle(Node* s);
+
 } // jittor
@@ -9,17 +9,6 @@
 
 namespace jittor {
 
-DEFINE_FLAG(int, amp_reg, 0, "Auto mixed-precision control registers, bit 0: prefer 32; bit 1: prefer 16; bit 2: keep reduce type; bit 3 keep white list type; bit 4: array like op prefer too");
-
-DEFINE_FLAG_WITH_SETTER(int, auto_mixed_precision_level, 0, "Auto mixed-precision optimization level, 0: not use fp16, 1-3: preserve level, not use fp16 for now; 4: perfer fp16, but some ops use fp32 e.g. sum,exp; 5: simular with 4, and array op will automatically convert to fp16; 6: all ops prefer fp16");
-
-void setter_auto_mixed_precision_level(int value) {
-    if (value <= 3) amp_reg = 0; else
-    if (value == 4) amp_reg = amp_prefer16; else
-    if (value == 5) amp_reg = amp_prefer16 | amp_array_prefer; else
-    if (value == 6) amp_reg = amp_prefer16 | amp_array_prefer | amp_keep_reduce | amp_keep_white;
-}
-
 #define FOR_ALL_TYPES(m) \
     m(bool) \
     m(int8) \
@@ -59,6 +59,7 @@ struct NodeFlags {
         _prefer_16=_n+10,
         // bit11: reduce keep type unchange
         _reduce_keep=_n+11,
+        _custom_flag=_reduce_keep,
     };
 
     inline void set(Flags f, int a=1, int nbits=1) {
@@ -6,6 +6,7 @@
 // ***************************************************************
 #include <cmath>
 #include "var.h"
+#include "executor.h"
 #include "ops/getitem_op.h"
 #include "ops/op_register.h"
 #ifdef JIT_cuda

@@ -27,6 +28,8 @@ namespace jittor {
 
 static auto make_number = get_op_info("number")
     .get_constructor<VarPtr, float, Var*>();
+static auto make_empty = get_op_info("empty")
+    .get_constructor<VarPtr, NanoVector, NanoString>();
 static auto make_setitem = get_op_info("setitem")
     .get_constructor<VarPtr, Var*, VarSlices&&, Var*, NanoString>();
 

@@ -38,6 +41,19 @@ GetitemOp::GetitemOp(Var* x, VarSlices&& slices)
     create_output(nullptr, x->dtype());
 }
 
+GetitemOp::GetitemOp(Var* x, VarSlices&& slices, int _)
+    : vs(move(slices)) {
+    flags.set(NodeFlags::_cpu);
+    flags.set(NodeFlags::_cuda);
+    flags.set(NodeFlags::_has_gopt);
+    flags.set(NodeFlags::_custom_flag);
+    flags.set(NodeFlags::_grads);
+    create_output(nullptr, x->dtype());
+    auto out2 = create_output(nullptr, x->dtype());
+    out2->share_with(x);
+    ns.data = _;
+}
+
 void GetitemOp::infer_slices(
     StackVector<>& __restrict__ i_to_vs,
     StackVector<>& __restrict__ i_to_o,

@@ -377,6 +393,10 @@ void GetitemOp::infer_shape() {
     this->i_to_vs = i_to_vs.to_nano_vector();
     this->i_to_o = i_to_o.to_nano_vector();
     this->o_shape = o_shape.to_nano_vector();
+    if (outputs().size() > 1) {
+        auto out2 = output(1);
+        out2->set_shape(in->shape);
+    }
 
     LOGV(999) << "\ni_to_vs:" << i_to_vs
         << "\ni_to_o:" << i_to_o

@@ -396,6 +416,24 @@ VarPtr GetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
     return make_setitem(zeros, VarSlices(vs, true), dout, ns_void);
 }
 
+void GetitemOp::grads(Var** dout, VarPtr* dins) {
+    VarPtr x = dout[1];
+    VarPtr y = dout[0];
+    if (!x) {
+        auto in = inputs().front();
+        if (in->num<0) exe.run_sync(vector<Var*>({in}), true);
+        // ns.data represents this is the last split var
+        if (ns.data)
+            x = make_empty(in->shape, in->dtype());
+        else
+            x = make_number(0, in);
+    }
+    if (!y) {
+        y = make_number(0, outputs().front());
+    }
+    dins[0] = make_setitem(x, VarSlices(vs, true), y, ns_void);
+}
+
 void GetitemOp::jit_prepare(JK& jk) {
     auto in = inputs().front();
     int idim = i_to_vs.size();
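Seen from Python, the new grads() path lets gradients flow through split slices without keeping a dense zero gradient per slice: a missing slice gradient becomes a number(0) var, and the tagged last slice starts from an empty buffer that the setitem chain then overwrites. A small sketch mirroring the test_split_grad case added further down (sizes are illustrative):

```python
import jittor as jt
import numpy as np

x = jt.random((6,))
a, b = x.split(3)                     # two chained getitem ops
ya, yb = jt.rand(3), jt.rand(3)
g = jt.grad(jt.concat([a * ya, b * yb]), x)
# Each slice's gradient is its multiplier, stitched back together by setitem
np.testing.assert_allclose(g.data, jt.concat([ya, yb]).data, rtol=1e-5, atol=1e-5)
```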
@@ -22,9 +22,12 @@ struct GetitemOp : Op {
     int first_oid_of_var, var_dim;
 
     GetitemOp(Var* x, VarSlices&& slices);
+    // @attrs(multiple_outputs)
+    GetitemOp(Var* x, VarSlices&& slices, int _);
 
     const char* name() const override { return "getitem"; }
     VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
+    void grads(Var** dout, VarPtr* dins) override;
     void infer_shape() override;
     void compile_optimize(string& src) override;
     void graph_optimize() override;
@@ -28,6 +28,8 @@ static auto make_array = get_op_info("array")
     .get_constructor<VarPtr, const void*, NanoVector, NanoString>();
 static auto make_getitem = get_op_info("getitem")
     .get_constructor<VarPtr, Var*, VarSlices&&>();
+static auto make_getitem2 = get_op_info("getitem")
+    .get_constructor<vector<VarPtr>, Var*, VarSlices&&, int>();
 static auto make_setitem = get_op_info("setitem")
     .get_constructor<VarPtr, Var*, VarSlices&&, Var*, NanoString>();
 static auto make_binary = get_op_info("binary")

@@ -40,8 +42,11 @@ SetitemOp::SetitemOp(Var* x, VarSlices&& slices, Var* y, NanoString op)
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
     flags.set(NodeFlags::_has_gopt);
-    ASSERT(ns == ns_void || ns.is_binary());
+    ASSERT(op == ns_void || op.is_binary());
     create_output(nullptr, x->dtype());
+    if (flags.get(NodeFlags::_custom_flag)) {
+        flags.set(NodeFlags::_grads);
+    }
 }
 
 void SetitemOp::infer_shape() {

@@ -120,6 +125,12 @@ void SetitemOp::infer_shape() {
         << "\no_shape:" << o_shape;
 }
 
+void SetitemOp::grads(Var** dout, VarPtr* dins) {
+    auto outs = make_getitem2(dout[0], VarSlices(vs, true), 0);
+    dins[0] = move(outs[1]);
+    dins[1] = move(outs[0]);
+}
+
 VarPtr SetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
     if (v_index >= 2)
         return nullptr;
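The new grads() here computes both gradients of a setitem with a single two-output getitem: the sliced part becomes the gradient of the assigned value, and the pass-through part becomes the gradient of the destination. A rough sketch of the expected behaviour, assuming in-place slice assignment on a Var lowers to this setitem op (expected values shown as comments):

```python
import jittor as jt

x = jt.zeros((4,))
y = jt.ones((2,))
z = x.clone()
z[1:3] = y                         # slice assignment -> setitem op
gx, gy = jt.grad(z.sum(), [x, y])
print(gx.numpy())                  # expected [1, 0, 0, 1]: overwritten slots go to y
print(gy.numpy())                  # expected [1, 1]
```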
@@ -25,6 +25,7 @@ struct SetitemOp : Op {
 
     const char* name() const override { return "setitem"; }
     VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
+    void grads(Var** dout, VarPtr* dins) override;
     void infer_shape() override;
     void compile_optimize(string& src) override;
     void graph_optimize() override;
@@ -73,6 +73,9 @@ static void setitem_inplace(SetitemOp* op) {
         return;
     if (data->allocator)
         return;
+    auto data_op = data->input();
+    if (data_op->flags.get(NodeFlags::_custom_flag))
+        return;
 
     auto in_shape = input->shape;
     int64 inplace_size = 1;
@@ -22,6 +22,16 @@ DEFINE_FLAG(bool, no_grad, 0,
     "No grad for all jittor Var creation");
 DEFINE_FLAG(bool, no_fuse, 0,
     "No fusion optimization for all jittor Var creation");
+DEFINE_FLAG(int, amp_reg, 0, "Auto mixed-precision control registers, bit 0: prefer 32; bit 1: prefer 16; bit 2: keep reduce type; bit 3 keep white list type; bit 4: array like op prefer too");
+
+DEFINE_FLAG_WITH_SETTER(int, auto_mixed_precision_level, 0, "Auto mixed-precision optimization level, 0: not use fp16, 1-3: preserve level, not use fp16 for now; 4: perfer fp16, but some ops use fp32 e.g. sum,exp; 5: simular with 4, and array op will automatically convert to fp16; 6: all ops prefer fp16");
+
+void setter_auto_mixed_precision_level(int value) {
+    if (value <= 3) amp_reg = 0; else
+    if (value == 4) amp_reg = amp_prefer16; else
+    if (value == 5) amp_reg = amp_prefer16 | amp_array_prefer; else
+    if (value == 6) amp_reg = amp_prefer16 | amp_array_prefer | amp_keep_reduce | amp_keep_white;
+}
 
 Var::Var(NanoVector shape, NanoString dtype)
     : shape(shape),
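These are the flags that the @jt.flag_scope(amp_reg=4) decorator on concat toggles; they were moved here from the file where they are removed in the hunk further up. A hedged sketch of driving them through the documented setter, which maps levels to amp_reg bits (level 4 prefers fp16 but keeps fp32 for ops such as sum and exp):

```python
import jittor as jt

a = jt.random((8, 8))
b = jt.random((8, 8))

# Level 0 disables fp16; level 4 sets amp_reg = amp_prefer16 via the setter above.
with jt.flag_scope(auto_mixed_precision_level=4):
    c = jt.matmul(a, b)
c.sync()
```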
@@ -48,7 +48,7 @@ class TestConcatOp(unittest.TestCase):
     def test_concat_op(self):
         def check(tmp, dim=0):
             res1 = numpy_concat(tmp, dim=dim)
-            res2 = jt.contrib.concat(tmp, dim=dim)
+            res2 = jt.concat(tmp, dim=dim)
             assert (res2!=res1).data.sum()==0, "concat fail..."
         check([jt.array([[1],[2]]), jt.array([[2],[2]])])
         check([jt.array(np.array(range(24))).reshape((1,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])

@@ -68,7 +68,7 @@ class TestConcatOp(unittest.TestCase):
         arr = []
         for i in range(m):
             arr.append(a[(slice(None),)*dim + (slice(i*size,i*size+size),)])
-        b = jt.contrib.concat(arr, dim)
+        b = jt.concat(arr, dim)
         if backward:
             loss = b * a
             b = jt.grad(loss, a)

@@ -20,7 +20,7 @@ class TestContrib(unittest.TestCase):
                 arr1.append(a)
                 arr2.append(jt.array(a))
             x = np.concatenate(tuple(arr1), dim)
-            y = jt.contrib.concat(arr2, dim)
+            y = jt.concat(arr2, dim)
             assert (x==y.data).all(), (x, y.data, arr1, arr2)
         check([2,3,4], 0, 2)
         check([2,3,4], 1, 3)

@@ -104,7 +104,7 @@ class TestSingleArray(unittest.TestCase):
             arr1.append(a)
             arr2.append(jt.array(a))
         x = np.concatenate(tuple(arr1), dim)
-        y = jt.contrib.concat(arr2, dim)
+        y = jt.concat(arr2, dim)
         assert (x==y.data).all()
         check([1], 0, 20)
 
@@ -39,7 +39,7 @@ class TestSetitem(unittest.TestCase):
 
         arr21 = jt.ones((2,2))
         arr22 = jt.ones((2,2)) * 2
-        arr2 = jt.contrib.concat([arr21, arr22], dim=0)
+        arr2 = jt.concat([arr21, arr22], dim=0)
         arr2.sync()
         arr21.data[0,0] = 3
         arr22.data[0,0] = 4

@@ -283,14 +283,17 @@ class TestSetitem(unittest.TestCase):
             np.concatenate([b.data,c.data]))
 
     def test_concat_random(self):
-        def check():
+        def check(backward=False):
            n1, n2, n3 = 1000, 20, 10
-            # n1, n2, n3 = 2, 2, 1
+            # n1, n2, n3 = 3, 2, 3
            import random
            data = []
+            back = []
            for i in range(n1):
                if len(data) > n2:
-                    del data[random.randint(0,len(data)-1)]
+                    v = random.randint(0,len(data)-1)
+                    # print("del", v)
+                    del data[v]
                x1 = random.randint(0,9)
                # print(i, x1)
                if len(data) == 0:

@@ -319,7 +322,15 @@ class TestSetitem(unittest.TestCase):
                elif x1 == 4:
                    # a = jt.random((random.randint(10,20),))
                    a = jt.array(np.random.rand(random.randint(n3,n3*2)))
+                    if backward and random.randint(0,1):
+                        back.append(a)
                    data.append(a)
+                elif x1 == 5:
+                    v = random.randint(0,len(data)-1)
+                    a = data[v]
+                    # print("split", v, a.shape)
+                    arr = a.split(n3-1)
+                    data += arr
                else:
                    if not len(data): continue
                    n = random.randint(1,3)

@@ -329,19 +340,49 @@ class TestSetitem(unittest.TestCase):
                    b = np.random.permutation(np.arange(a.numel()))
                    a = a[b][:100]
                    data.append(a)
-            ret = jt.concat(data).numpy()
-            # print(data)
-            return ret
+            ret = jt.concat(data)
+            if backward and len(back):
+                grads = jt.grad(jt.rand_like(ret)*ret, back)
+                return jt.concat(grads).numpy()
+            return ret.numpy()
 
-        for s in range(1000):
-            jt.set_global_seed(s)
-            data = check()
-            jt.gc()
-            jt.set_global_seed(s)
-            with jt.flag_scope(gopt_disable=1):
-                data2 = check()
-            jt.gc()
-            np.testing.assert_allclose(data, data2)
+        for s in range(100):
+            print("check", s)
+            for check_grad in [True, False]:
+                jt.set_global_seed(s)
+                data = check(check_grad)
+                jt.gc()
+                jt.set_global_seed(s)
+                with jt.flag_scope(gopt_disable=1):
+                    data2 = check(check_grad)
+                jt.gc()
+                np.testing.assert_allclose(data, data2, atol=1e-5, rtol=1e-5)
+
+    def test_concat_grad(self):
+        n = 30000
+        m = 100
+        arr = []
+        for i in range(n):
+            arr.append(jt.random((m,)))
+        x = jt.concat(arr)
+        y = jt.rand_like(x)
+        grads = jt.grad(x*y, arr)
+        for i in range(n):
+            np.testing.assert_allclose(grads[i].numpy(), y[i*m:(i+1)*m].numpy())
+
+    def test_split_grad(self):
+        n = 30000
+        m = 100
+        x = jt.random((n*m,))
+        arr = x.split(m)
+        yy = [ jt.rand(m) for i in range(n) ]
+        arr2 = [ y*yy[i] for i,y in enumerate(arr) ]
+        g = jt.grad(jt.concat(arr2), x)
+        for i in range(n):
+            np.testing.assert_allclose(g.data[i*m:(i+1)*m], yy[i].data)
 
 
 if __name__ == "__main__":
@@ -560,6 +560,15 @@ try:
 except:
     pass
 
+try:
+    import sys
+    sys.setrecursionlimit(10**6)
+    if os.name != 'nt':
+        import resource
+        resource.setrlimit(resource.RLIMIT_STACK, (2**29,-1))
+except:
+    pass
+
 if os.name == 'nt':
     if check_msvc_install:
         if not os.path.isfile(cc_path):