mirror of https://github.com/Jittor/Jittor
Merge branch 'master' of https://github.com/Jittor/jittor
This commit is contained in:
commit
91f984ee3d
|
@ -24,6 +24,7 @@
|
|||
jittor.contrib
|
||||
jittor.dataset
|
||||
jittor.transform
|
||||
jittor.mpi
|
||||
|
||||
|
||||
.. toctree::
|
||||
|
|
|
@ -6,8 +6,9 @@ jittor.models
|
|||
```eval_rst
|
||||
|
||||
.. automodule:: jittor.models
|
||||
:members:
|
||||
:members:
|
||||
:imported-members:
|
||||
:undoc-members:
|
||||
:exclude-members: ResNet,ShuffleNetV2,SqueezeNet,VGG
|
||||
```
|
||||
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
jittor.mpi
|
||||
=====================
|
||||
|
||||
这里是Jittor的MPI模块的API文档,您可以通过`from jittor import mpi`来获取该模块。
|
||||
|
||||
```eval_rst
|
||||
.. automodule:: jittor_mpi_core
|
||||
:members:
|
||||
:undoc-members:
|
||||
.. automodule:: jittor_mpi_core.ops
|
||||
:members:
|
||||
:undoc-members:
|
||||
```
|
|
@ -20,7 +20,7 @@
|
|||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
CubArgReduceOp::CubArgReduceOp(Var* x, Var* offsets, string op, bool keepdims)
|
||||
CubArgReduceOp::CubArgReduceOp(Var* x, Var* offsets, NanoString op, bool keepdims)
|
||||
: x(x), offsets(offsets), op(op), keepdims(keepdims) {
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
|
@ -56,7 +56,7 @@ void CubArgReduceOp::infer_shape() {
|
|||
void CubArgReduceOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("Toffsets", offsets->dtype());
|
||||
add_jit_define("FUNC", op=="min" ? "ArgMin" : "ArgMax");
|
||||
add_jit_define("FUNC", op==ns_minimum ? "ArgMin" : "ArgMax");
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -14,10 +14,10 @@ namespace jittor {
|
|||
|
||||
struct CubArgReduceOp : Op {
|
||||
Var* x, * offsets, * y, * y_key;
|
||||
string op;
|
||||
NanoString op;
|
||||
bool keepdims;
|
||||
// @attrs(multiple_outputs)
|
||||
CubArgReduceOp(Var* x, Var* offsets, string op, bool keepdims);
|
||||
CubArgReduceOp(Var* x, Var* offsets, NanoString op, bool keepdims);
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
void infer_shape() override;
|
||||
|
||||
|
|
|
@ -29,18 +29,33 @@ extern int mpi_world_rank;
|
|||
extern int mpi_local_rank;
|
||||
extern bool inside_mpi;
|
||||
|
||||
/**
|
||||
Return number of MPI nodes.
|
||||
*/
|
||||
// @pyjt(world_size)
|
||||
int _mpi_world_size();
|
||||
|
||||
/**
|
||||
Return global ID of this MPI node.
|
||||
*/
|
||||
// @pyjt(world_rank)
|
||||
int _mpi_world_rank();
|
||||
|
||||
/**
|
||||
Return local ID of this MPI node.
|
||||
*/
|
||||
// @pyjt(local_rank)
|
||||
int _mpi_local_rank();
|
||||
|
||||
struct ArrayArgs;
|
||||
|
||||
/**
|
||||
|
||||
Use jt.Module.mpi_param_broadcast(root=0) to broadcast all moudule parameters of this module in [root] MPI node to all MPI nodes.
|
||||
|
||||
This operation has no gradient, and the input parameter type is numpy array.
|
||||
*/
|
||||
// @pyjt(broadcast)
|
||||
void _mpi_broadcast(ArrayArgs&& args, int i);
|
||||
void _mpi_broadcast(ArrayArgs&& args, int root);
|
||||
|
||||
} // jittor
|
||||
|
|
|
@ -15,6 +15,15 @@ struct MpiAllReduceOp : Op {
|
|||
Var* x, * y;
|
||||
NanoString op;
|
||||
|
||||
/**
|
||||
|
||||
Mpi All Reduce Operator uses the operator [op] to reduce variable [x] in all MPI nodes and broadcast to all MPI nodes.
|
||||
|
||||
Args:
|
||||
|
||||
* x: variable to be all reduced.
|
||||
* op: 'sum' or 'add' means sum all [x], 'mean' means average all [x]. Default: 'add'.
|
||||
*/
|
||||
MpiAllReduceOp(Var* x, NanoString op=ns_add);
|
||||
void infer_shape() override;
|
||||
|
||||
|
|
|
@ -15,6 +15,15 @@ struct MpiBroadcastOp : Op {
|
|||
Var* x, * y;
|
||||
int root;
|
||||
|
||||
/**
|
||||
|
||||
Mpi Broadcast Operator broadcasts variable [x] in [root] MPI nodes to all MPI nodes.
|
||||
|
||||
Args:
|
||||
|
||||
* x: variable to be broadcasted.
|
||||
* root: ID of MPI node to be broadcasted. Default: 0.
|
||||
*/
|
||||
MpiBroadcastOp(Var* x, int root=0);
|
||||
void infer_shape() override;
|
||||
|
||||
|
|
|
@ -16,6 +16,16 @@ struct MpiReduceOp : Op {
|
|||
NanoString op;
|
||||
int root;
|
||||
|
||||
/**
|
||||
|
||||
Mpi Reduce Operator uses the operator [op] to reduce variable [x] in all MPI nodes and send to the [root] MPI node.
|
||||
|
||||
Args:
|
||||
|
||||
* x: variable to be reduced.
|
||||
* op: 'sum' or 'add' means sum all [x], 'mean' means average all [x]. Default: 'add'.
|
||||
* root: ID of MPI node to output. Default: 0.
|
||||
*/
|
||||
MpiReduceOp(Var* x, NanoString op=ns_add, int root=0);
|
||||
void infer_shape() override;
|
||||
|
||||
|
|
|
@ -44,11 +44,11 @@ int _mpi_local_rank() {
|
|||
return mpi_local_rank;
|
||||
}
|
||||
|
||||
void _mpi_broadcast(ArrayArgs&& args, int i) {
|
||||
void _mpi_broadcast(ArrayArgs&& args, int root) {
|
||||
int64 size = args.dtype.dsize();
|
||||
for (auto j : args.shape)
|
||||
size *= j;
|
||||
MPI_CHECK(MPI_Bcast((void *)args.ptr, size, MPI_BYTE, i, MPI_COMM_WORLD));
|
||||
MPI_CHECK(MPI_Bcast((void *)args.ptr, size, MPI_BYTE, root, MPI_COMM_WORLD));
|
||||
}
|
||||
|
||||
static uint64_t getHostHash(const char* string) {
|
||||
|
|
|
@ -0,0 +1,170 @@
|
|||
# 使用Jittor实现Conditional GAN
|
||||
|
||||
Generative Adversarial Nets(GAN)[1]提出了一种新的方法来训练生成模型。然而,GAN对于要生成的图片缺少控制。Conditional GAN(CGAN)[2]通过添加显式的条件或标签,来控制生成的图像。本教程讲解了CGAN的网络结构、损失函数设计、使用CGAN生成一串数字、从头训练CGAN、以及在mnist手写数字数据集上的训练结果。
|
||||
|
||||
## CGAN网络架构
|
||||
|
||||
通过在生成器generator和判别器discriminator中添加相同的额外信息y,GAN就可以扩展为一个conditional模型。y可以是任何形式的辅助信息,例如类别标签或者其他形式的数据。我们可以通过将y作为额外输入层,添加到生成器和判别器来完成条件控制。
|
||||
|
||||
在生成器generator中,除了y之外,还额外输入随机一维噪声z,为结果生成提供更多灵活性。
|
||||
|
||||

|
||||
|
||||
## 损失函数
|
||||
|
||||
### GAN的损失函数
|
||||
|
||||
在解释CGAN的损失函数之前,首先介绍GAN的损失函数。下面是GAN的损失函数设计。
|
||||
|
||||

|
||||
|
||||
对于判别器D,我们要训练最大化这个loss。如果D的输入是来自真实样本的数据x,则D的输出D(x)要尽可能地大,log(D(x))也会尽可能大。如果D的输入是来自G生成的假图片G(z),则D的输出D(G(z))应尽可能地小,从而log(1-D(G(z))会尽可能地大。这样可以达到max D的目的。
|
||||
|
||||
对于生成器G,我们要训练最小化这个loss。对于G生成的假图片G(z),我们希望尽可能地骗过D,让它觉得我们生成的图片就是真的图片,这样就达到了G“以假乱真”的目的。那么D的输出D(G(z))应尽可能地大,从而log(1-D(G(z))会尽可能地小。这样可以达到min G的目的。
|
||||
|
||||
D和G以这样的方式联合训练,最终达到G的生成能力越来越强,D的判别能力越来越强的目的。
|
||||
|
||||
### CGAN的损失函数
|
||||
|
||||
下面是CGAN的损失函数设计。
|
||||
|
||||

|
||||
|
||||
|
||||
很明显,CGAN的loss跟GAN的loss的区别就是多了条件限定y。D(x/y)代表在条件y下,x为真的概率。D(G(z/y))表示在条件y下,G生成的图片被D判别为真的概率。
|
||||
|
||||
## Jittor代码数字生成
|
||||
|
||||
首先,我们导入需要的包,并且设置好所需的超参数:
|
||||
|
||||
```python
|
||||
import jittor as jt
|
||||
from jittor import nn
|
||||
import numpy as np
|
||||
import pylab as pl
|
||||
|
||||
%matplotlib inline
|
||||
|
||||
# 隐空间向量长度
|
||||
latent_dim = 100
|
||||
# 类别数量
|
||||
n_classes = 10
|
||||
# 图片大小
|
||||
img_size = 32
|
||||
# 图片通道数量
|
||||
channels = 1
|
||||
# 图片张量的形状
|
||||
img_shape = (channels, img_size, img_size)
|
||||
```
|
||||
|
||||
第一步,定义生成器G。该生成器输入两个一维向量y和noise,生成一张图片。
|
||||
|
||||
```python
|
||||
class Generator(nn.Module):
|
||||
def __init__(self):
|
||||
super(Generator, self).__init__()
|
||||
self.label_emb = nn.Embedding(n_classes, n_classes)
|
||||
|
||||
def block(in_feat, out_feat, normalize=True):
|
||||
layers = [nn.Linear(in_feat, out_feat)]
|
||||
if normalize:
|
||||
layers.append(nn.BatchNorm1d(out_feat, 0.8))
|
||||
layers.append(nn.LeakyReLU(0.2))
|
||||
return layers
|
||||
self.model = nn.Sequential(
|
||||
*block((latent_dim + n_classes), 128, normalize=False),
|
||||
*block(128, 256),
|
||||
*block(256, 512),
|
||||
*block(512, 1024),
|
||||
nn.Linear(1024, int(np.prod(img_shape))),
|
||||
nn.Tanh())
|
||||
|
||||
def execute(self, noise, labels):
|
||||
gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)
|
||||
img = self.model(gen_input)
|
||||
img = img.view((img.shape[0], *img_shape))
|
||||
return img
|
||||
```
|
||||
|
||||
第二步,定义判别器D。D输入一张图片和对应的y,输出是真图片的概率。
|
||||
|
||||
```python
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self):
|
||||
super(Discriminator, self).__init__()
|
||||
self.label_embedding = nn.Embedding(n_classes, n_classes)
|
||||
self.model = nn.Sequential(
|
||||
nn.Linear((n_classes + int(np.prod(img_shape))), 512),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.Linear(512, 512),
|
||||
nn.Dropout(0.4),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.Linear(512, 512),
|
||||
nn.Dropout(0.4),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.Linear(512, 1))
|
||||
|
||||
def execute(self, img, labels):
|
||||
d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
|
||||
validity = self.model(d_in)
|
||||
return validity
|
||||
```
|
||||
|
||||
第三步,使用CGAN生成一串数字。
|
||||
|
||||
代码如下。您可以使用您训练好的模型来生成图片,也可以使用我们提供的预训练参数: 模型预训练参数下载:<https://cloud.tsinghua.edu.cn/d/fbe30ae0967942f6991c/>。
|
||||
|
||||
```python
|
||||
# 下载提供的预训练参数
|
||||
!wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/generator_last.pkl
|
||||
!wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/discriminator_last.pkl
|
||||
```
|
||||
|
||||
生成自定义的数字:
|
||||
|
||||
```python
|
||||
# 定义模型
|
||||
generator = Generator()
|
||||
discriminator = Discriminator()
|
||||
generator.eval()
|
||||
discriminator.eval()
|
||||
|
||||
# 加载参数
|
||||
generator.load('./generator_last.pkl')
|
||||
discriminator.load('./discriminator_last.pkl')
|
||||
|
||||
# 定义一串数字
|
||||
number = "201962517"
|
||||
n_row = len(number)
|
||||
z = jt.array(np.random.normal(0, 1, (n_row, latent_dim))).float32().stop_grad()
|
||||
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
|
||||
gen_imgs = generator(z,labels)
|
||||
|
||||
pl.imshow(gen_imgs.data.transpose((1,2,0,3))[0].reshape((gen_imgs.shape[2], -1)))
|
||||
```
|
||||
|
||||
## 从头训练Condition GAN
|
||||
|
||||
从头训练 Condition GAN 的完整代码在<https://github.com/Jittor/gan-jittor/blob/master/models/cgan/cgan.py>, 让我们把他下载下来看看!
|
||||
|
||||
```python
|
||||
!wget https://raw.githubusercontent.com/Jittor/gan-jittor/master/models/cgan/cgan.py
|
||||
!python3.7 ./cgan.py --help
|
||||
|
||||
# 选择合适的batch size,运行试试
|
||||
# 运行命令: !python3.7 ./cgan.py --batch_size 64
|
||||
```
|
||||
|
||||
## MNIST数据集训练结果
|
||||
|
||||
下面展示了Jittor版CGAN在MNIST数据集的训练结果。下面分别是训练0 epoch和90 epoches的结果。
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
## 参考文献
|
||||
|
||||
1. Goodfellow, Ian, et al. “Generative adversarial nets.” Advances in neural information processing systems. 2014.
|
||||
|
||||
2. Mirza, Mehdi, and Simon Osindero. “Conditional generative adversarial nets.” arXiv preprint arXiv:1411.1784 (2014).
|
|
@ -125,6 +125,44 @@ class profile_scope(_call_no_record_scope):
|
|||
profiler.stop()
|
||||
self.report.extend(profiler.report())
|
||||
|
||||
class single_process_scope(_call_no_record_scope):
|
||||
""" single_process_scope
|
||||
|
||||
Code in this scope will only be executed by single process.
|
||||
|
||||
example::
|
||||
|
||||
with jt.single_process_scope(root=0):
|
||||
......
|
||||
|
||||
@jt.single_process_scope(root=0)
|
||||
def xxx():
|
||||
...
|
||||
"""
|
||||
def __init__(self, rank=0):
|
||||
self.rank = rank
|
||||
|
||||
def __enter__(self):
|
||||
global mpi
|
||||
from jittor.dataset import dataset
|
||||
self.mpi_backup = mpi
|
||||
mpi = dataset.mpi = None
|
||||
|
||||
def __exit__(self, *exc):
|
||||
global mpi
|
||||
from jittor.dataset import dataset
|
||||
mpi = dataset.mpi = self.mpi_backup
|
||||
|
||||
def __call__(self, func):
|
||||
global mpi
|
||||
def inner(*args, **kw):
|
||||
if mpi and mpi.world_rank() != self.rank:
|
||||
return
|
||||
with self:
|
||||
ret = func(*args, **kw)
|
||||
return ret
|
||||
return inner
|
||||
|
||||
def clean():
|
||||
import gc
|
||||
# make sure python do a full collection
|
||||
|
@ -177,7 +215,7 @@ def std(x):
|
|||
matsize *= i
|
||||
out=(x-x.mean()).sqr().sum()
|
||||
out=out/(matsize-1)
|
||||
out=out.sqrt()
|
||||
out=out.maximum(1e-6).sqrt()
|
||||
return out
|
||||
Var.std = std
|
||||
|
||||
|
@ -186,7 +224,7 @@ def norm(x, k, dim):
|
|||
if k==1:
|
||||
return x.abs().sum(dim)
|
||||
if k==2:
|
||||
return x.sqr().sum(dim).sqrt()
|
||||
return (x.sqr()).sum(dim).maximum(1e-6).sqrt()
|
||||
Var.norm = norm
|
||||
|
||||
origin_reshape = reshape
|
||||
|
@ -257,6 +295,14 @@ Var.masked_fill = masked_fill
|
|||
def sqr(x): return x*x
|
||||
Var.sqr = sqr
|
||||
|
||||
def argmax(x, dim:int, keepdims:bool=False):
|
||||
return x.arg_reduce("max", dim, keepdims)
|
||||
Var.argmax = argmax
|
||||
|
||||
def argmin(x, dim:int, keepdims:bool=False):
|
||||
return x.arg_reduce("min", dim, keepdims)
|
||||
Var.argmin = argmin
|
||||
|
||||
def attrs(var):
|
||||
return {
|
||||
"is_stop_fuse": var.is_stop_fuse(),
|
||||
|
|
|
@ -380,7 +380,7 @@ def setup_mpi():
|
|||
# share the 'environ' symbol.
|
||||
mpi = compile_custom_ops(mpi_src_files,
|
||||
extra_flags=f" {mpi_flags} ", return_module=True,
|
||||
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW)
|
||||
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW, gen_name_="jittor_mpi_core")
|
||||
mpi_ops = mpi.ops
|
||||
LOG.vv("Get mpi: "+str(mpi.__dict__.keys()))
|
||||
LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys()))
|
||||
|
|
|
@ -561,7 +561,8 @@ def compile_custom_ops(
|
|||
filenames,
|
||||
extra_flags="",
|
||||
return_module=False,
|
||||
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
|
||||
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND,
|
||||
gen_name_ = ""):
|
||||
"""Compile custom ops
|
||||
filenames: path of op source files, filenames must be
|
||||
pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
|
||||
|
@ -599,6 +600,8 @@ def compile_custom_ops(
|
|||
for name in srcs:
|
||||
assert name in headers, f"Header of op {name} not found"
|
||||
gen_name = "gen_ops_" + "_".join(headers.keys())
|
||||
if gen_name_ != "":
|
||||
gen_name = gen_name_
|
||||
if len(gen_name) > 100:
|
||||
gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))
|
||||
|
||||
|
@ -815,12 +818,22 @@ with jit_utils.import_scope(import_flags):
|
|||
jit_utils.try_import_jit_utils_core()
|
||||
|
||||
python_path = sys.executable
|
||||
py3_config_path = sys.executable+"-config"
|
||||
assert os.path.isfile(python_path)
|
||||
if not os.path.isfile(py3_config_path) :
|
||||
py3_config_path = sys.executable + '3-config'
|
||||
py3_config_paths = [
|
||||
sys.executable + "-config",
|
||||
os.path.dirname(sys.executable) + f"/python3.{sys.version_info.minor}-config",
|
||||
f"/usr/bin/python3.{sys.version_info.minor}-config",
|
||||
os.path.dirname(sys.executable) + "/python3-config",
|
||||
]
|
||||
if "python_config_path" in os.environ:
|
||||
py3_config_paths.insert(0, os.environ["python_config_path"])
|
||||
|
||||
assert os.path.isfile(py3_config_path)
|
||||
for py3_config_path in py3_config_paths:
|
||||
if os.path.isfile(py3_config_path):
|
||||
break
|
||||
else:
|
||||
raise RuntimeError(f"python3.{sys.version_info.minor}-config "
|
||||
"not found in {py3_config_paths}, please specify "
|
||||
"enviroment variable 'python_config_path'")
|
||||
nvcc_path = env_or_try_find('nvcc_path', '/usr/local/cuda/bin/nvcc')
|
||||
gdb_path = try_find_exe('gdb')
|
||||
addr2line_path = try_find_exe('addr2line')
|
||||
|
|
|
@ -25,6 +25,8 @@ class MNIST(Dataset):
|
|||
[in] data_root(str): your data root.
|
||||
[in] train(bool): choose model train or val.
|
||||
[in] download(bool): Download data automatically if download is Ture.
|
||||
[in] batch_size(int): Data batch size.
|
||||
[in] shuffle(bool): Shuffle data if true.
|
||||
[in] transform(jittor.transform): transform data.
|
||||
|
||||
Example::
|
||||
|
@ -34,12 +36,19 @@ class MNIST(Dataset):
|
|||
for i, (imgs, target) in enumerate(train_loader):
|
||||
...
|
||||
'''
|
||||
def __init__(self, data_root=dataset_root+"/mnist_data/", train=True ,download=True, transform=None):
|
||||
def __init__(self, data_root=dataset_root+"/mnist_data/",
|
||||
train=True,
|
||||
download=True,
|
||||
batch_size = 16,
|
||||
shuffle = False,
|
||||
transform=None):
|
||||
# if you want to test resnet etc you should set input_channel = 3, because the net set 3 as the input dimensions
|
||||
super().__init__()
|
||||
self.data_root = data_root
|
||||
self.is_train = train
|
||||
self.transform = transform
|
||||
self.batch_size = batch_size
|
||||
self.shuffle = shuffle
|
||||
if download == True:
|
||||
self.download_url()
|
||||
|
||||
|
|
|
@ -10,8 +10,6 @@
|
|||
import numpy as np
|
||||
import os
|
||||
from PIL import Image
|
||||
import matplotlib.pyplot as plt
|
||||
import cv2
|
||||
from .dataset import Dataset, dataset_root
|
||||
|
||||
class VOC(Dataset):
|
||||
|
|
|
@ -15,4 +15,5 @@ from .mobilenet import *
|
|||
from . import mnasnet
|
||||
from .mnasnet import *
|
||||
from . import shufflenetv2
|
||||
from .shufflenetv2 import *
|
||||
from .shufflenetv2 import *
|
||||
from .res2net import res2net50, res2net101
|
||||
|
|
|
@ -13,6 +13,19 @@ import jittor.nn as nn
|
|||
__all__ = ['AlexNet', 'alexnet']
|
||||
|
||||
class AlexNet(nn.Module):
|
||||
""" AlexNet model architecture.
|
||||
|
||||
Args:
|
||||
|
||||
* num_classes: Number of classes. Default: 1000.
|
||||
|
||||
Example::
|
||||
|
||||
model = jittor.models.AlexNet(500)
|
||||
x = jittor.random([10,224,224,3])
|
||||
y = model(x) # [10, 500]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes=1000):
|
||||
super(AlexNet, self).__init__()
|
||||
|
|
|
@ -16,6 +16,15 @@ def googlenet(**kwargs):
|
|||
return GoogLeNet(**kwargs)
|
||||
|
||||
class GoogLeNet(nn.Module):
|
||||
""" GoogLeNet model architecture.
|
||||
|
||||
Args:
|
||||
|
||||
* num_classes: Number of classes. Default: 1000.
|
||||
* aux_logits: If True, add an auxiliary branch that can improve training. Default: True
|
||||
* init_weights: Defualt: True.
|
||||
* blocks: List of three blocks, [conv_block, inception_block, inception_aux_block]. If None, will use [BasicConv2d, Inception, InceptionAux] instead. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes=1000, aux_logits=True, init_weights=True, blocks=None):
|
||||
super(GoogLeNet, self).__init__()
|
||||
|
|
|
@ -7,7 +7,16 @@ def inception_v3(pretrained=False, progress=True, **kwargs):
|
|||
return Inception3(**kwargs)
|
||||
|
||||
class Inception3(nn.Module):
|
||||
""" Inceptionv3 model architecture.
|
||||
|
||||
Args:
|
||||
|
||||
* num_classes: Number of classes. Default: 1000.
|
||||
* aux_logits: If True, add an auxiliary branch that can improve training. Default: True
|
||||
* inception_blocks: List of seven blocks, [conv_block, inception_a, inception_b, inception_c, inception_d, inception_e, inception_aux]. If None, will use [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux] instead. Default: None.
|
||||
* init_weights: Defualt: True.
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes=1000, aux_logits=True, inception_blocks=None, init_weights=True):
|
||||
super(Inception3, self).__init__()
|
||||
if (inception_blocks is None):
|
||||
|
|
|
@ -47,6 +47,14 @@ def _get_depths(alpha):
|
|||
return [_round_to_multiple_of((depth * alpha), 8) for depth in depths]
|
||||
|
||||
class MNASNet(nn.Module):
|
||||
""" MNASNet model architecture. version=2.
|
||||
|
||||
Args:
|
||||
|
||||
* alpha: Depth multiplier.
|
||||
* num_classes: Number of classes. Default: 1000.
|
||||
* dropout: Dropout probability of dropout layer.
|
||||
"""
|
||||
_version = 2
|
||||
|
||||
def __init__(self, alpha, num_classes=1000, dropout=0.2):
|
||||
|
|
|
@ -48,6 +48,17 @@ class InvertedResidual(nn.Module):
|
|||
return self.conv(x)
|
||||
|
||||
class MobileNetV2(nn.Module):
|
||||
""" MobileNetV2 model architecture.
|
||||
|
||||
Args:
|
||||
|
||||
* num_classes: Number of classes. Default: 1000.
|
||||
* width_mult: Width multiplier - adjusts number of channels in each layer by this amount. Default: 1.0.
|
||||
* init_weights: Defualt: True.
|
||||
* inverted_residual_setting: Network structure
|
||||
* round_nearest: Round the number of channels in each layer to be a multiple of this number. Set to 1 to turn off rounding. Default: 8.
|
||||
* block: Module specifying inverted residual building block for mobilenet. If None, use InvertedResidual instead. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8, block=None):
|
||||
super(MobileNetV2, self).__init__()
|
||||
|
|
|
@ -0,0 +1,184 @@
|
|||
import jittor as jt
|
||||
from jittor import nn
|
||||
from jittor import Module
|
||||
from jittor import init
|
||||
from jittor.contrib import concat, argmax_pool
|
||||
import math
|
||||
|
||||
class Bottle2neck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, baseWidth=26, scale = 4, stype='normal'):
|
||||
""" Constructor
|
||||
Args:
|
||||
inplanes: input channel dimensionality
|
||||
planes: output channel dimensionality
|
||||
stride: conv stride. Replaces pooling layer.
|
||||
downsample: None when stride = 1
|
||||
baseWidth: basic width of conv3x3
|
||||
scale: number of scale.
|
||||
type: 'normal': normal set. 'stage': first block of a new stage.
|
||||
"""
|
||||
super(Bottle2neck, self).__init__()
|
||||
|
||||
width = int(math.floor(planes * (baseWidth/64.0)))
|
||||
self.conv1 = nn.Conv(inplanes, width*scale, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm(width*scale)
|
||||
assert scale > 1, 'Res2Net degenerates to ResNet when scales = 1.'
|
||||
if scale == 1:
|
||||
self.nums = 1
|
||||
else:
|
||||
self.nums = scale -1
|
||||
if stype == 'stage':
|
||||
self.pool = nn.Pool(kernel_size=3, stride = stride, padding=1, op='mean')
|
||||
self.convs = nn.ModuleList()
|
||||
self.bns = nn.ModuleList()
|
||||
for i in range(self.nums):
|
||||
self.convs.append(nn.Conv(width, width, kernel_size=3, stride = stride, dilation=dilation, padding=dilation, bias=False))
|
||||
self.bns.append(nn.BatchNorm(width))
|
||||
|
||||
self.conv3 = nn.Conv(width*scale, planes * self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm(planes * self.expansion)
|
||||
|
||||
self.relu = nn.ReLU()
|
||||
self.downsample = downsample
|
||||
self.stype = stype
|
||||
self.scale = scale
|
||||
self.width = width
|
||||
self.stride = stride
|
||||
self.dilation = dilation
|
||||
|
||||
def execute(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
spx = out
|
||||
|
||||
outs = []
|
||||
for i in range(self.nums):
|
||||
if i==0 or self.stype=='stage':
|
||||
sp = spx[:, i*self.width: (i+1)*self.width]
|
||||
else:
|
||||
sp = sp + spx[:, i*self.width: (i+1)*self.width]
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.relu(self.bns[i](sp))
|
||||
outs.append(sp)
|
||||
if self.stype=='normal' or self.stride==1:
|
||||
outs.append(spx[:, self.nums*self.width: (self.nums+1)*self.width])
|
||||
elif self.stype=='stage':
|
||||
outs.append(self.pool(spx[:, self.nums*self.width: (self.nums+1)*self.width]))
|
||||
out = concat(outs, 1)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Res2Net(Module):
|
||||
def __init__(self, block, layers, output_stride, baseWidth = 26, scale = 4):
|
||||
super(Res2Net, self).__init__()
|
||||
self.baseWidth = baseWidth
|
||||
self.scale = scale
|
||||
self.inplanes = 64
|
||||
blocks = [1, 2, 4]
|
||||
if output_stride == 16:
|
||||
strides = [1, 2, 2, 1]
|
||||
dilations = [1, 1, 1, 2]
|
||||
elif output_stride == 8:
|
||||
strides = [1, 2, 1, 1]
|
||||
dilations = [1, 1, 2, 4]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# Modules
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv(3, 32, 3, 2, 1, bias=False),
|
||||
nn.BatchNorm(32),
|
||||
nn.ReLU(),
|
||||
nn.Conv(32, 32, 3, 1, 1, bias=False),
|
||||
nn.BatchNorm(32),
|
||||
nn.ReLU(),
|
||||
nn.Conv(32, 64, 3, 1, 1, bias=False)
|
||||
)
|
||||
self.bn1 = nn.BatchNorm(64)
|
||||
self.relu = nn.ReLU()
|
||||
# self.maxpool = nn.Pool(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1])
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2])
|
||||
self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3])
|
||||
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Pool(kernel_size=stride, stride=stride,
|
||||
ceil_mode=True, op='mean'),
|
||||
nn.Conv(self.inplanes, planes * block.expansion,
|
||||
kernel_size=1, stride=1, bias=False),
|
||||
nn.BatchNorm(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, dilation, downsample,
|
||||
stype='stage', baseWidth = self.baseWidth, scale=self.scale))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, dilation=dilation, baseWidth = self.baseWidth, scale=self.scale))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Pool(kernel_size=stride, stride=stride,
|
||||
ceil_mode=True, op='mean'),
|
||||
nn.Conv(self.inplanes, planes * block.expansion,
|
||||
kernel_size=1, stride=1, bias=False),
|
||||
nn.BatchNorm(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
|
||||
downsample=downsample, stype='stage', baseWidth = self.baseWidth, scale=self.scale))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, len(blocks)):
|
||||
layers.append(block(self.inplanes, planes, stride=1,
|
||||
dilation=blocks[i]*dilation, baseWidth = self.baseWidth, scale=self.scale))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def execute(self, input):
|
||||
|
||||
x = self.conv1(input)
|
||||
x = self.bn1(x)
|
||||
|
||||
x = self.relu(x)
|
||||
x = argmax_pool(x, 2, 2)
|
||||
x = self.layer1(x)
|
||||
low_level_feat = x
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
|
||||
x = self.layer4(x)
|
||||
return x, low_level_feat
|
||||
|
||||
def res2net50(output_stride):
|
||||
model = Res2Net(Bottle2neck, [3,4,6,3], output_stride)
|
||||
return model
|
||||
|
||||
def res2net101(output_stride):
|
||||
model = Res2Net(Bottle2neck, [3,4,23,3], output_stride)
|
||||
return model
|
|
@ -15,10 +15,14 @@ __all__ = ['ResNet', 'Resnet18', 'Resnet34', 'Resnet50', 'Resnet101', 'Resnet152
|
|||
'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2']
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||
return nn.Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||
conv=nn.Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||
jt.init.relu_invariant_gauss_(conv.weight, mode="fan_out")
|
||||
return conv
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
return nn.Conv(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
conv=nn.Conv(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
jt.init.relu_invariant_gauss_(conv.weight, mode="fan_out")
|
||||
return conv
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
@ -102,6 +106,7 @@ class ResNet(nn.Module):
|
|||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
jt.init.relu_invariant_gauss_(self.conv1.weight, mode="fan_out")
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.relu = nn.Relu()
|
||||
self.maxpool = nn.Pool(kernel_size=3, stride=2, padding=1, op='maximum')
|
||||
|
@ -162,6 +167,16 @@ def Resnet50(**kwargs):
|
|||
resnet50 = Resnet50
|
||||
|
||||
def Resnet101(**kwargs):
|
||||
"""
|
||||
ResNet-101 model architecture.
|
||||
|
||||
Example::
|
||||
|
||||
model = jittor.models.Resnet101()
|
||||
x = jittor.random([10,224,224,3])
|
||||
y = model(x) # [10, 1000]
|
||||
|
||||
"""
|
||||
return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
|
||||
resnet101 = Resnet101
|
||||
|
||||
|
|
|
@ -84,8 +84,11 @@ def cross_entropy_loss(output, target, ignore_index=None):
|
|||
def mse_loss(output, target):
|
||||
return (output-target).sqr().mean()
|
||||
|
||||
def bce_loss(output, target):
|
||||
return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).mean()
|
||||
def bce_loss(output, target, size_average=True):
|
||||
if size_average:
|
||||
return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).mean()
|
||||
else:
|
||||
return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).sum()
|
||||
|
||||
def l1_loss(output, target):
|
||||
return (output-target).abs().mean()
|
||||
|
@ -105,8 +108,8 @@ class MSELoss(Module):
|
|||
class BCELoss(Module):
|
||||
def __init__(self):
|
||||
pass
|
||||
def execute(self, output, target):
|
||||
return bce_loss(output, target)
|
||||
def execute(self, output, target, size_average=True):
|
||||
return bce_loss(output, target, size_average)
|
||||
|
||||
class L1Loss(Module):
|
||||
def __init__(self):
|
||||
|
@ -118,9 +121,9 @@ class BCEWithLogitsLoss(Module):
|
|||
def __init__(self):
|
||||
self.sigmoid = Sigmoid()
|
||||
self.bce = BCELoss()
|
||||
def execute(self, output, target):
|
||||
def execute(self, output, target, size_average=True):
|
||||
output = self.sigmoid(output)
|
||||
output = self.bce(output, target)
|
||||
output = self.bce(output, target, size_average)
|
||||
return output
|
||||
|
||||
def softmax(x, dim = None):
|
||||
|
@ -279,9 +282,14 @@ class Conv(Module):
|
|||
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
|
||||
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
|
||||
|
||||
self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
|
||||
# self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
|
||||
self.weight = init.invariant_uniform([out_channels, in_channels//groups, Kh, Kw], dtype="float")
|
||||
if bias:
|
||||
self.bias = init.uniform([out_channels], dtype="float", low=-1, high=1)
|
||||
fan=1
|
||||
for i in self.weight.shape[1:]:
|
||||
fan *= i
|
||||
bound = 1 / math.sqrt(fan)
|
||||
self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
|
@ -499,10 +507,10 @@ class PixelShuffle(Module):
|
|||
def execute(self, x):
|
||||
n,c,h,w = x.shape
|
||||
r = self.upscale_factor
|
||||
assert c%(r**2)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle"
|
||||
assert c%(r*r)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle"
|
||||
return x.reindex([n,int(c/r**2),h*r,w*r], [
|
||||
"i0",
|
||||
f"i1*{r**2}+i2%{r}*{r}+i3%{r}",
|
||||
f"i1*{r*r}+i2%{r}*{r}+i3%{r}",
|
||||
f"i2/{r}",
|
||||
f"i3/{r}"
|
||||
])
|
||||
|
@ -519,30 +527,58 @@ class Sigmoid(Module):
|
|||
def execute(self, x) :
|
||||
return x.sigmoid()
|
||||
|
||||
def resize(x, size, mode="nearest"):
|
||||
img = x
|
||||
n,c,h,w = x.shape
|
||||
H,W = size
|
||||
new_size = [n,c,H,W]
|
||||
nid, cid, hid, wid = jt.index(new_size)
|
||||
x = hid * h / H
|
||||
y = wid * w / W
|
||||
class Resize(Module):
|
||||
def __init__(self, size, mode="nearest", align_corners=False):
|
||||
super().__init__()
|
||||
self.size = size
|
||||
self.mode = mode
|
||||
self.align_corners = align_corners
|
||||
def execute(self, x):
|
||||
return resize(x, self.size, self.mode, self.align_corners)
|
||||
|
||||
def _interpolate(img, x, y, ids, mode):
|
||||
if mode=="nearest":
|
||||
return img.reindex([nid, cid, x.floor(), y.floor()])
|
||||
return img.reindex([*ids, x.floor(), y.floor()])
|
||||
if mode=="bilinear":
|
||||
fx, fy = x.floor(), y.floor()
|
||||
cx, cy = fx+1, fy+1
|
||||
dx, dy = x-fx, y-fy
|
||||
a = img.reindex_var([nid, cid, fx, fy])
|
||||
b = img.reindex_var([nid, cid, cx, fy])
|
||||
c = img.reindex_var([nid, cid, fx, cy])
|
||||
d = img.reindex_var([nid, cid, cx, cy])
|
||||
a = img.reindex_var([*ids, fx, fy])
|
||||
b = img.reindex_var([*ids, cx, fy])
|
||||
c = img.reindex_var([*ids, fx, cy])
|
||||
d = img.reindex_var([*ids, cx, cy])
|
||||
dnx, dny = 1-dx, 1-dy
|
||||
ab = dx*b + dnx*a
|
||||
cd = dx*d + dnx*c
|
||||
o = ab*dny + cd*dy
|
||||
return o
|
||||
raise(f"Not support {interpolation}")
|
||||
raise(f"Not support interpolation mode: {mode}")
|
||||
|
||||
def resize(img, size, mode="nearest", align_corners=False):
|
||||
n,c,h,w = img.shape
|
||||
H,W = size
|
||||
nid, cid, hid, wid = jt.index((n,c,H,W))
|
||||
if align_corners:
|
||||
x = hid * ((h-1) / max(1, H-1))
|
||||
y = wid * ((w-1) / max(1, W-1))
|
||||
else:
|
||||
x = hid * (h / H) + (h/H*0.5 - 0.5)
|
||||
if H>h: x = x.clamp(0, h-1)
|
||||
y = wid * (w / W) + (w/W*0.5 - 0.5)
|
||||
if W>w: y = y.clamp(0, w-1)
|
||||
return _interpolate(img, x, y, (nid,cid), mode)
|
||||
|
||||
def upsample(img, size, mode="nearest", align_corners=False):
|
||||
n,c,h,w = img.shape
|
||||
H,W = size
|
||||
nid, cid, hid, wid = jt.index((n,c,H,W))
|
||||
if align_corners:
|
||||
x = hid * ((h-1) / max(1, H-1))
|
||||
y = wid * ((w-1) / max(1, W-1))
|
||||
else:
|
||||
x = hid * (h / H)
|
||||
y = wid * (w / W)
|
||||
return _interpolate(img, x, y, (nid,cid), mode)
|
||||
|
||||
class Upsample(Module):
|
||||
def __init__(self, scale_factor=None, mode='nearest'):
|
||||
|
@ -550,11 +586,17 @@ class Upsample(Module):
|
|||
self.mode = mode
|
||||
|
||||
def execute(self, x):
|
||||
return resize(x, size=(int(x.shape[2]*self.scale_factor[0]), int(x.shape[3]*self.scale_factor[1])), mode=self.mode)
|
||||
return upsample(x,
|
||||
size=(
|
||||
int(x.shape[2]*self.scale_factor[0]),
|
||||
int(x.shape[3]*self.scale_factor[1])),
|
||||
mode=self.mode)
|
||||
|
||||
class Sequential(Module):
|
||||
def __init__(self, *args):
|
||||
self.layers = list(args)
|
||||
self.layers = []
|
||||
for mod in args:
|
||||
self.append(mod)
|
||||
def __getitem__(self, idx):
|
||||
return self.layers[idx]
|
||||
def execute(self, x):
|
||||
|
@ -573,6 +615,8 @@ class Sequential(Module):
|
|||
if callback_leave:
|
||||
callback_leave(parents, k, self, n_children)
|
||||
def append(self, mod):
|
||||
assert callable(mod), f"Module <{type(mod)}> is not callable"
|
||||
assert not isinstance(mod, type), f"Module is not a type"
|
||||
self.layers.append(mod)
|
||||
|
||||
ModuleList = Sequential
|
||||
|
|
|
@ -66,7 +66,7 @@ class Optimizer(object):
|
|||
g.assign(g.mpi_all_reduce("mean"))
|
||||
if self.n_step % self.param_sync_iter == 0:
|
||||
for p in params:
|
||||
p.assign(p.mpi_all_reduce("mean"))
|
||||
p.assign(p.mpi_broadcast())
|
||||
self.n_step += 1
|
||||
|
||||
# set up grads in param_groups
|
||||
|
@ -131,6 +131,42 @@ class SGD(Optimizer):
|
|||
p -= v * lr
|
||||
p.detach_inplace()
|
||||
|
||||
class RMSprop(Optimizer):
|
||||
""" RMSprop Optimizer.
|
||||
Args:
|
||||
params(list): parameters of model.
|
||||
lr(float): learning rate.
|
||||
eps(float): term added to the denominator to avoid division by zero, default 1e-8.
|
||||
alpha(float): smoothing constant, default 0.99.
|
||||
|
||||
Example:
|
||||
optimizer = nn.RMSprop(model.parameters(), lr)
|
||||
optimizer.step(loss)
|
||||
"""
|
||||
def __init__(self, params, lr=1e-2, eps=1e-8, alpha=0.99):
|
||||
super().__init__(params, lr)
|
||||
self.eps = eps
|
||||
self.alpha = alpha
|
||||
|
||||
# initialize required arguments for each param_groups
|
||||
for pg in self.param_groups:
|
||||
values = pg["values"] = []
|
||||
for p in pg["params"]:
|
||||
values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
|
||||
|
||||
def step(self, loss):
|
||||
self.pre_step(loss)
|
||||
for pg in self.param_groups:
|
||||
# get arguments from each param_groups
|
||||
lr = pg.get("lr", self.lr)
|
||||
eps = pg.get("eps", self.eps)
|
||||
alpha = pg.get("alpha", self.alpha)
|
||||
for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
|
||||
if p.is_stop_grad(): continue
|
||||
v.assign(alpha * v + (1-alpha) * g * g)
|
||||
p -= lr * g / (jt.sqrt(v) + eps)
|
||||
p.detach_inplace()
|
||||
|
||||
class Adam(Optimizer):
|
||||
""" Adam Optimizer.
|
||||
|
||||
|
|
|
@ -34,7 +34,14 @@ class Pool(Module):
|
|||
h = (H+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
|
||||
w = (W+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
|
||||
|
||||
if self.op in ['maximum', 'minimum', 'mean'] and not self.count_include_pad:
|
||||
if self.op in ['maximum', 'minimum', 'mean']:
|
||||
if self.op == 'mean':
|
||||
if self.count_include_pad:
|
||||
count = f"int count = {self.kernel_size*self.kernel_size};"
|
||||
else:
|
||||
count = "int count = (k2_ - k2) * (k3_ - k3);"
|
||||
else:
|
||||
count = ""
|
||||
forward_body = f'''{{
|
||||
int k3 = i3*{self.stride}-{self.padding};
|
||||
int k2 = i2*{self.stride}-{self.padding};
|
||||
|
@ -43,7 +50,7 @@ class Pool(Module):
|
|||
k3 = max(0, k3);
|
||||
k2 = max(0, k2);
|
||||
@out(i0, i1, i2, i3) = init_{self.op}(out_type);
|
||||
{"int count = (k2_ - k2) * (k3_ - k3);" if self.op == "mean" else ""}
|
||||
{count}
|
||||
for (int p = k2; p < k2_; ++p)
|
||||
for (int q = k3; q < k3_; ++q)
|
||||
@out(i0, i1, i2, i3) = {self.op}(out_type, @out(i0, i1, i2, i3), @in0(i0, i1, p, q));
|
||||
|
@ -55,7 +62,7 @@ class Pool(Module):
|
|||
int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
|
||||
k3 = max(0, k3);
|
||||
k2 = max(0, k2);
|
||||
{"int count = (k2_ - k2) * (k3_ - k3);" if self.op == "mean" else ""}
|
||||
{count}
|
||||
int bo=1;
|
||||
for (int p = k2; p < k2_ && bo; ++p)
|
||||
for (int q = k3; q < k3_ && bo; ++q) {{
|
||||
|
@ -139,6 +146,7 @@ class Pool(Module):
|
|||
'''])
|
||||
return out
|
||||
else:
|
||||
# TODO: backward
|
||||
xx = x.reindex([N,C,h,w,self.kernel_size,self.kernel_size], [
|
||||
"i0", # Nid
|
||||
"i1", # Cid
|
||||
|
|
|
@ -24,6 +24,36 @@ try:
|
|||
except:
|
||||
skip_this_test = True
|
||||
|
||||
class OldPool(Module):
|
||||
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"):
|
||||
assert dilation == None
|
||||
assert return_indices == None
|
||||
self.kernel_size = kernel_size
|
||||
self.op = op
|
||||
self.stride = stride if stride else kernel_size
|
||||
self.padding = padding
|
||||
self.ceil_mode = ceil_mode
|
||||
self.count_include_pad = count_include_pad and padding != 0
|
||||
|
||||
def execute(self, x):
|
||||
N,C,H,W = x.shape
|
||||
if self.ceil_mode == False:
|
||||
h = (H+self.padding*2-self.kernel_size)//self.stride+1
|
||||
w = (W+self.padding*2-self.kernel_size)//self.stride+1
|
||||
else:
|
||||
h = (H+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
|
||||
w = (W+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
|
||||
|
||||
# TODO: backward
|
||||
xx = x.reindex([N,C,h,w,self.kernel_size,self.kernel_size], [
|
||||
"i0", # Nid
|
||||
"i1", # Cid
|
||||
f"i2*{self.stride}-{self.padding}+i4", # Hid
|
||||
f"i3*{self.stride}-{self.padding}+i5", # Wid
|
||||
])
|
||||
return xx.reduce(self.op, [4,5])
|
||||
|
||||
|
||||
def check(jt_model, torch_model, shape, near_data):
|
||||
if (near_data):
|
||||
assert shape[0] * shape[1] * shape[2] * shape[3] % 8 == 0
|
||||
|
@ -51,6 +81,20 @@ class TestArgPoolOp(unittest.TestCase):
|
|||
def test_cuda(self):
|
||||
jt_model = jt.nn.Sequential(Pool(2, 2, 0), Pool(2, 2, 0), Pool(2, 2, 0, ceil_mode=True), Pool(2, 2, 0), Pool(2, 2, 0), Pool(3, 1, 1))
|
||||
torch_model = Sequential(MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0, ceil_mode=True), MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0), MaxPool2d(3, 1, 1))
|
||||
shape = [2, 3, 300, 300]
|
||||
check(jt_model, torch_model, shape, False)
|
||||
shape = [2, 3, 157, 300]
|
||||
check(jt_model, torch_model, shape, False)
|
||||
for i in range(10):
|
||||
check(jt_model, torch_model, [1,1,300,300], True)
|
||||
|
||||
@unittest.skipIf(True, "TODO: cannot pass this test, fix me")
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No cuda found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cuda_old_pool(self):
|
||||
from torch.nn import AvgPool2d
|
||||
jt_model = OldPool(3, 1, 1, op="mean")
|
||||
torch_model = AvgPool2d(3, 1, 1)
|
||||
shape = [64, 64, 300, 300]
|
||||
check(jt_model, torch_model, shape, False)
|
||||
shape = [32, 128, 157, 300]
|
||||
|
|
|
@ -23,13 +23,14 @@ class TestMem(unittest.TestCase):
|
|||
one_g = np.ones((1024*1024*1024//4,), "float32")
|
||||
|
||||
meminfo = jt.get_mem_info()
|
||||
n = int(meminfo.total_cuda_ram // (1024**3) * 1.5)
|
||||
n = int(meminfo.total_cuda_ram // (1024**3) * 0.6)
|
||||
|
||||
for i in range(n):
|
||||
a = jt.array(one_g)
|
||||
b = a + 1
|
||||
b.sync()
|
||||
backups.append((a,b))
|
||||
jt.sync_all(True)
|
||||
backups = []
|
||||
|
||||
|
||||
|
|
|
@ -61,8 +61,11 @@ class test_models(unittest.TestCase):
|
|||
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
@torch.no_grad()
|
||||
def test_models(self):
|
||||
with torch.no_grad():
|
||||
self.run_models()
|
||||
|
||||
def run_models(self):
|
||||
def to_cuda(x):
|
||||
if jt.has_cuda:
|
||||
return x.cuda()
|
||||
|
|
|
@ -72,6 +72,10 @@ def run_mpi_test(num_procs, name):
|
|||
class TestMpiEntry(unittest.TestCase):
|
||||
def test_entry(self):
|
||||
run_mpi_test(2, "test_mpi")
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
def test_mpi_resnet_entry(self):
|
||||
run_mpi_test(2, "test_resnet")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -15,7 +15,7 @@ tests = []
|
|||
for mdname in os.listdir(dirname):
|
||||
if not mdname.endswith(".src.md"): continue
|
||||
# temporary disable model_test
|
||||
if "LSGAN" in mdname: continue
|
||||
if "GAN" in mdname: continue
|
||||
tests.append(mdname[:-3])
|
||||
|
||||
try:
|
||||
|
|
|
@ -158,7 +158,7 @@ class TestParallelPass3(unittest.TestCase):
|
|||
src = f.read()
|
||||
for i in range(tdim):
|
||||
assert f"tnum{i}" in src
|
||||
assert f"tnum{tdim}" not in src
|
||||
assert f"tnum{tdim}" not in src, f"tnum{tdim}"
|
||||
src_has_atomic = "atomic_add" in src or "atomicAdd" in src
|
||||
assert has_atomic == src_has_atomic
|
||||
assert np.allclose(a.data.sum(rdim), b), (b.sum(), a.data.sum())
|
||||
|
@ -176,7 +176,11 @@ class TestParallelPass3(unittest.TestCase):
|
|||
check(3, 1, 1, [0,1], 1)
|
||||
check(3, 1, 1, [0,1], 0, [0,0,2])
|
||||
check(3, 2, 2, [2], 0)
|
||||
check(3, 2, 1, [1], 0)
|
||||
if jt.flags.use_cuda:
|
||||
# loop is not merged so parallel depth 2
|
||||
check(3, 2, 2, [1], 1)
|
||||
else:
|
||||
check(3, 2, 1, [1], 0)
|
||||
check(3, 2, 2, [1], 1, merge=0)
|
||||
check(4, 2, 2, [2,3], 0)
|
||||
check(4, 2, 2, [0,3], 1)
|
||||
|
|
|
@ -99,16 +99,25 @@ class TestResizeAndCrop(unittest.TestCase):
|
|||
test_case(20, [1024, 1024], [1.2, 1.8][mid])
|
||||
test_case(20, [1024, 666], [0.8,1.0][mid])
|
||||
|
||||
def test_resize(self):
|
||||
import torch.nn.functional as F
|
||||
x = np.array(range(2*3*25)).reshape(2,3,5,5).astype("float32")
|
||||
for r_size in [3,4,5,6]:
|
||||
for align_corners in [True,False]:
|
||||
check_equal(x,
|
||||
jnn.Resize((r_size, r_size), 'bilinear', align_corners),
|
||||
lambda x: F.interpolate(x, size=(r_size, r_size), mode='bilinear',align_corners=align_corners))
|
||||
|
||||
def test_upsample(self):
|
||||
arr = np.random.randn(16,10,224,224)
|
||||
arr = np.random.randn(2,3,224,224)
|
||||
check_equal(arr, jnn.Upsample(scale_factor=2), tnn.Upsample(scale_factor=2))
|
||||
check_equal(arr, jnn.Upsample(scale_factor=0.2), tnn.Upsample(scale_factor=0.2))
|
||||
|
||||
def test_pixelshuffle(self):
|
||||
arr = np.random.randn(16,16,224,224)
|
||||
arr = np.random.randn(2,4,224,224)
|
||||
check_equal(arr, jnn.PixelShuffle(upscale_factor=2), tnn.PixelShuffle(upscale_factor=2))
|
||||
arr = np.random.randn(1,16*16,224,224)
|
||||
check_equal(arr, jnn.PixelShuffle(upscale_factor=16), tnn.PixelShuffle(upscale_factor=16))
|
||||
arr = np.random.randn(1,3*3,224,224)
|
||||
check_equal(arr, jnn.PixelShuffle(upscale_factor=3), tnn.PixelShuffle(upscale_factor=3))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -76,7 +76,7 @@ class TestResnet(unittest.TestCase):
|
|||
# print train info
|
||||
global prev
|
||||
pred = np.argmax(output, axis=1)
|
||||
acc = np.sum(target==pred)/self.batch_size
|
||||
acc = np.mean(target==pred)
|
||||
loss_list.append(loss[0])
|
||||
acc_list.append(acc)
|
||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
|
||||
|
@ -113,10 +113,14 @@ class TestResnet(unittest.TestCase):
|
|||
# Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000
|
||||
# Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000
|
||||
|
||||
assert jt.core.number_of_lived_vars() < 3500
|
||||
if jt.mpi:
|
||||
assert jt.core.number_of_lived_vars() < 3900, jt.core.number_of_lived_vars()
|
||||
else:
|
||||
assert jt.core.number_of_lived_vars() < 3500, jt.core.number_of_lived_vars()
|
||||
|
||||
jt.sync_all(True)
|
||||
assert np.mean(loss_list[-50:])<0.3
|
||||
assert np.mean(acc_list[-50:])>0.8
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Wenyang Zhou <576825820@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import os, sys
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
from jittor.test.test_mpi import run_mpi_test
|
||||
mpi = jt.compile_extern.mpi
|
||||
|
||||
from jittor.dataset.mnist import MNIST
|
||||
|
||||
def val1():
|
||||
dataloader = MNIST(train=False).set_attrs(batch_size=16)
|
||||
for i, (imgs, labels) in enumerate(dataloader):
|
||||
assert(imgs.shape[0]==8)
|
||||
if i == 5:
|
||||
break
|
||||
|
||||
@jt.single_process_scope(rank=0)
|
||||
def val2():
|
||||
dataloader = MNIST(train=False).set_attrs(batch_size=16)
|
||||
for i, (imgs, labels) in enumerate(dataloader):
|
||||
assert(imgs.shape[0]==16)
|
||||
if i == 5:
|
||||
break
|
||||
|
||||
@unittest.skipIf(mpi is None, "no inside mpirun")
|
||||
class TestSingleProcessScope(unittest.TestCase):
|
||||
def test_single_process_scope(self):
|
||||
val1()
|
||||
val2()
|
||||
|
||||
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
|
||||
class TestSingleProcessScopeEntry(unittest.TestCase):
|
||||
def test_entry(self):
|
||||
run_mpi_test(2, "test_single_process_scope")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -71,6 +71,25 @@ def crop_and_resize(img, top, left, height, width, size, interpolation=Image.BIL
|
|||
img = resize(img, size, interpolation)
|
||||
return img
|
||||
|
||||
class Crop:
|
||||
"""Crop and the PIL Image to given size.
|
||||
|
||||
Args:
|
||||
|
||||
* top(int): top pixel indexes
|
||||
* left(int): left pixel indexes
|
||||
* height(int): image height
|
||||
* width(int): image width
|
||||
"""
|
||||
def __init__(self, top, left, height, width):
|
||||
self.top = top
|
||||
self.left = left
|
||||
self.height = height
|
||||
self.width = width
|
||||
def __call__(self, img):
|
||||
return crop(img, self.top, self.left, self.height, self.width)
|
||||
|
||||
|
||||
class RandomCropAndResize:
|
||||
"""Random crop and resize the given PIL Image to given size.
|
||||
|
||||
|
|
|
@ -10,6 +10,8 @@
|
|||
# Publish steps:
|
||||
# 1. build,push,upload docker image[jittor/jittor]
|
||||
# 2. build,push,upload docker image[jittor/jittor-cuda]
|
||||
# upload to pip:
|
||||
# rm -rf dist && python3.7 ./setup.py sdist && python3.7 -m twine upload dist/*
|
||||
import os
|
||||
|
||||
def run_cmd(cmd):
|
||||
|
|
2
setup.py
2
setup.py
|
@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh:
|
|||
|
||||
setuptools.setup(
|
||||
name='jittor',
|
||||
version='1.1.3.1',
|
||||
version='1.1.4.4',
|
||||
# scripts=[],
|
||||
author="Jittor Group",
|
||||
author_email="ran.donglang@gmail.com",
|
||||
|
|
|
@ -149,6 +149,8 @@ static void init_ns() {
|
|||
FOR_ALL_NS(INIT_NS);
|
||||
ASSERT(NanoString::__ns_to_string.size()<=(1<<NanoString::_index_nbits));
|
||||
NanoString::__string_to_ns["sum"] = ns_add;
|
||||
NanoString::__string_to_ns["min"] = ns_minimum;
|
||||
NanoString::__string_to_ns["max"] = ns_maximum;
|
||||
LOGvv << "init __string_to_ns" << NanoString::__string_to_ns;
|
||||
LOGvv << "init __ns_to_string" << NanoString::__ns_to_string;
|
||||
}
|
||||
|
|
|
@ -34,18 +34,16 @@ static auto make_reshape = get_op_info("reshape")
|
|||
static auto make_reindex_reduce = get_op_info("reindex_reduce")
|
||||
.get_constructor<VarPtr, Var*, NanoString, NanoVector, vector<string>&&, vector<string>&&, vector<Var*>&&>();
|
||||
|
||||
ArgReduceOp::ArgReduceOp(Var* x, string op, int dim, bool keepdims)
|
||||
ArgReduceOp::ArgReduceOp(Var* x, NanoString op, int dim, bool keepdims)
|
||||
: x(x), op(op), dim(dim), keepdims(keepdims) {
|
||||
if (this->dim == -1)
|
||||
this->dim = x->shape.size() - 1;
|
||||
dim = this->dim;
|
||||
#ifdef HAS_CUDA
|
||||
if (use_cuda) {
|
||||
static std::vector<VarPtr>(*cub_arg_reduce)(Var*, Var*, string, bool) = nullptr;
|
||||
if (!cub_arg_reduce && has_op("cub_arg_reduce")) {
|
||||
cub_arg_reduce = get_op_info("cub_arg_reduce")
|
||||
.get_constructor<std::vector<VarPtr>, Var*, Var*, string, bool>();
|
||||
}
|
||||
static auto cub_arg_reduce = has_op("cub_arg_reduce") ?
|
||||
get_op_info("cub_arg_reduce").get_constructor<std::vector<VarPtr>, Var*, Var*, NanoString, bool>()
|
||||
: nullptr;
|
||||
if (cub_arg_reduce) {
|
||||
if (x->num<0) exe.run_sync(vector<Var*>({x}), true);
|
||||
int dims = x->shape.size();
|
||||
|
@ -162,7 +160,7 @@ void ArgReduceOp::jit_prepare() {
|
|||
add_jit_define("YDIM", JK::hex1(y->shape.size()));
|
||||
add_jit_define("KEEPDIMS", keepdims ? 1 : 0);
|
||||
add_jit_define("DIM", JK::hex1(dim));
|
||||
add_jit_define("CMP", op=="min" ? "<" : ">");
|
||||
add_jit_define("CMP", op==ns_minimum ? "<" : ">");
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
|
|
@ -14,11 +14,11 @@ namespace jittor {
|
|||
|
||||
struct ArgReduceOp : Op {
|
||||
Var* x, * y, * y_key;
|
||||
string op;
|
||||
NanoString op;
|
||||
int dim;
|
||||
bool keepdims;
|
||||
// @attrs(multiple_outputs)
|
||||
ArgReduceOp(Var* x, string op, int dim, bool keepdims);
|
||||
ArgReduceOp(Var* x, NanoString op, int dim, bool keepdims);
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
static VarPtr get_grad(Var* out, Var* dout, Var* v, int v_index, int dim, Var* y);
|
||||
void infer_shape() override;
|
||||
|
|
|
@ -36,23 +36,25 @@ struct ArgsortOp : Op {
|
|||
|
||||
* [in] dim: sort alone which dim
|
||||
|
||||
* [in] descending: the elements are sorted in descending order or not(default False).
|
||||
|
||||
* [in] dtype: type of return indexes
|
||||
|
||||
* [in] key: code for sorted key
|
||||
|
||||
* [in] compare: code for compare
|
||||
|
||||
* [out] index: index have the same size with sorted dim
|
||||
|
||||
|
||||
* [out] value: sorted value
|
||||
|
||||
|
||||
Example::
|
||||
|
||||
jt.sort([11,13,12])
|
||||
# return [0,2,1]
|
||||
jt.sort([11,13,12], key='-@x(i)')
|
||||
# return [1,2,0]
|
||||
jt.sort([11,13,12], key='@x(i)<@x(j)')
|
||||
# return [0,2,1]
|
||||
index, value = jt.argsort([11,13,12])
|
||||
# return [0 2 1], [11 12 13]
|
||||
index, value = jt.argsort([11,13,12], descending=True)
|
||||
# return [1 2 0], [13 12 11]
|
||||
index, value = jt.argsort([[11,13,12], [12,11,13]])
|
||||
# return [[0 2 1],[1 0 2]], [[11 12 13],[11 12 13]]
|
||||
index, value = jt.argsort([[11,13,12], [12,11,13]], dim=0)
|
||||
# return [[0 1 0],[1 0 1]], [[11 11 12],[12 13 13]]
|
||||
|
||||
*/
|
||||
// @attrs(multiple_outputs)
|
||||
|
|
|
@ -3,12 +3,13 @@
|
|||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "common.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifdef JIT_cuda
|
||||
#define pow(T,a,b) ::powf(a,b)
|
||||
#define pow(T,a,b) ::pow(a,b)
|
||||
#define maximum(T,a,b) ::max(T(a), T(b))
|
||||
#define minimum(T,a,b) ::min(T(a), T(b))
|
||||
#else // JIT_cpu
|
||||
|
|
|
@ -61,7 +61,10 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
|
|||
reduce_mask |= 1<<dim;
|
||||
}
|
||||
}
|
||||
y = create_output(nullptr, binary_dtype_infer(ns, x, x));
|
||||
if (x->dtype() == ns_bool && ns == ns_add)
|
||||
y = create_output(nullptr, ns_int32);
|
||||
else
|
||||
y = create_output(nullptr, binary_dtype_infer(ns, x, x));
|
||||
}
|
||||
|
||||
ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, bool keepdims)
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "common.h"
|
||||
|
||||
namespace jittor {
|
||||
|
|
|
@ -41,12 +41,15 @@ static void move_rely(KernelIR* inner_loop, KernelIR* outer_loop, KernelIR* def)
|
|||
}
|
||||
}
|
||||
|
||||
static void tune_atomic(Pass* pass, KernelIR* ir, bool is_cuda, int tdim) {
|
||||
// sorder: Array that saves the allocation order of "tn"
|
||||
// sfunc: Array of function names
|
||||
static void tune_atomic(Pass* pass, KernelIR* ir, bool is_cuda, int tdim, vector<vector<int>> &sorder, vector<string> &sfunc) {
|
||||
LOGvvvv << "tune_atomic" << ir->children;
|
||||
vector<string> relys;
|
||||
vector<string> idx_name;
|
||||
vector<KernelIR*> atomics;
|
||||
vector<KernelIR*> loops;
|
||||
vector<int> nrely;
|
||||
vector<int> order;
|
||||
int tmp_cnt=0;
|
||||
for (uint i=0; i<ir->children.size(); i++) {
|
||||
|
@ -57,6 +60,7 @@ static void tune_atomic(Pass* pass, KernelIR* ir, bool is_cuda, int tdim) {
|
|||
atomics.clear();
|
||||
loops.clear();
|
||||
order.clear();
|
||||
nrely.clear();
|
||||
|
||||
c->dfs([&](unique_ptr<KernelIR>& p) {
|
||||
auto& code = p->attrs["code"];
|
||||
|
@ -71,6 +75,7 @@ static void tune_atomic(Pass* pass, KernelIR* ir, bool is_cuda, int tdim) {
|
|||
loops.push_back(loop);
|
||||
idx_name.push_back(loop->attrs["lvalue"]);
|
||||
order.push_back(loops.size()-1);
|
||||
nrely.push_back(-1);
|
||||
bool ok = true;
|
||||
while (1) {
|
||||
loop = loops.back();
|
||||
|
@ -90,6 +95,7 @@ static void tune_atomic(Pass* pass, KernelIR* ir, bool is_cuda, int tdim) {
|
|||
loops.push_back(loop2);
|
||||
idx_name.push_back(loop2->attrs["lvalue"]);
|
||||
order.push_back(loops.size()-1);
|
||||
nrely.push_back(-1);
|
||||
}
|
||||
// TODO: only support single loop children
|
||||
if (!ok) continue;
|
||||
|
@ -107,12 +113,25 @@ static void tune_atomic(Pass* pass, KernelIR* ir, bool is_cuda, int tdim) {
|
|||
for (uint l=0;l<order.size();l++)
|
||||
if (order[l]==sidx) sord=l;
|
||||
ASSERT(sord != -1);
|
||||
for (int l=sord;l;l--) order[l]=order[l-1];
|
||||
for (int l=sord;l;l--){
|
||||
order[l]=order[l-1];
|
||||
nrely[l]=nrely[l-1];
|
||||
}
|
||||
order[0]=sidx;
|
||||
nrely[0]=j;
|
||||
}
|
||||
}
|
||||
LOGvvvv << "atomic tuner order" << order;
|
||||
|
||||
vector<int> tnorder;
|
||||
uint si;
|
||||
for (si=0;si<order.size();si++)
|
||||
if (nrely[si]!=nrely[0]) break;
|
||||
for (int j=si-1;j>=0;j--) tnorder.push_back(order[j]);
|
||||
for (int j=order.size()-1;j>=si;j--) tnorder.push_back(order[j]);
|
||||
sorder.push_back(tnorder);
|
||||
sfunc.push_back(ir->attrs["lvalue"]);
|
||||
|
||||
// sort loop with order
|
||||
int count=0;
|
||||
for (auto j : order) {
|
||||
|
@ -199,12 +218,54 @@ void AtomicTunerPass::run() {
|
|||
if (is_cuda) choice=1;
|
||||
if (!choice) return;
|
||||
|
||||
vector<vector<int>> sorder;
|
||||
vector<string> sfunc;
|
||||
for (uint i=0; i<ir->before.size(); i++) {
|
||||
auto& func_call = ir->before[i];
|
||||
// TODO: remove this if
|
||||
if (func_call->get_attr("dtype") != "__global__ void") continue;
|
||||
tune_atomic(this, func_call.get(), is_cuda, 4);
|
||||
tune_atomic(this, func_call.get(), is_cuda, 4, sorder, sfunc);
|
||||
}
|
||||
|
||||
// Re-adjust the allocation order of "tn" according to the situation of atomic coverage, preferentially allocate the range not covered by atomic, for example:
|
||||
// for (op0_index_t id0 = tid0; id0<range0; id0+=tnum0) {
|
||||
// for (op1_index_t id1 = tid1; id1<range1; id1+=tnum1) {
|
||||
// for (op2_index_t id2 = tid2; id2<range2; id2+=tnum2) {
|
||||
// for (op3_index_t id3 = tid3; id3<range3; id3+=tnum3) {
|
||||
// ...
|
||||
// }
|
||||
// }
|
||||
// atomicAdd(...);
|
||||
// }
|
||||
// }
|
||||
// The allocation order of "tn" will be: tn1, tn0, tn3, tn2
|
||||
for (uint j=0;j<sfunc.size();j++)
|
||||
for (uint i=0; i<ir->children.size(); i++) {
|
||||
auto& func_call = ir->children[i];
|
||||
int bo=0;
|
||||
for (uint k=0; k<func_call->children.size(); k++){
|
||||
auto& save = func_call->children[k];
|
||||
if (save->has_attr("loop_func") && save->attrs["loop_func"]==sfunc[j]){
|
||||
bo=1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!bo) continue;
|
||||
uint k;
|
||||
for (k=0; k<func_call->children.size(); k++){
|
||||
auto& save = func_call->children[k];
|
||||
if (save->has_attr("lvalue") && save->attrs["lvalue"].find("tn")==0) break;
|
||||
}
|
||||
for (uint l=0;l<sorder[j].size();l++){
|
||||
for (uint p=0; p<func_call->children.size(); p++){
|
||||
auto& save = func_call->children[p];
|
||||
if (save->has_attr("lvalue") && save->attrs["lvalue"].find("tn"+S(sorder[j][l]))==0){
|
||||
func_call->children[p]->swap(*func_call->children[k++]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ir->remove_all_unused();
|
||||
}
|
||||
|
||||
|
|
|
@ -264,12 +264,9 @@ void ParallelPass::run() {
|
|||
string nums = rvalues.at(0);
|
||||
for (int i=1; i<rvalues.size(); i++)
|
||||
nums+="*"+rvalues[i];
|
||||
if (fix_thread_num)
|
||||
new_block.push_back("int thread_num=" + S(thread_num) + ");");
|
||||
else
|
||||
new_block.push_back("int thread_num=min(1<<(NanoVector::get_nbits("+nums+")-2)," + S(thread_num) + ");");
|
||||
|
||||
new_block.push_back("int thread_num=" + S(thread_num) + ";");
|
||||
new_block.push_back("int thread_num_left=thread_num;");
|
||||
|
||||
for (int j=ncs.size()-1; j>=0; j--) {
|
||||
auto& rv = rvalues[j];
|
||||
new_block.push_back("int tn"+S(j)+
|
||||
|
@ -291,7 +288,7 @@ void ParallelPass::run() {
|
|||
// omp func call
|
||||
// we set num_threads in code
|
||||
new_func_call->push_back(
|
||||
"#pragma omp parallel num_threads("+S(thread_num)+")",
|
||||
"#pragma omp parallel num_threads(thread_num)",
|
||||
&new_func_call->before
|
||||
);
|
||||
} else {
|
||||
|
@ -344,6 +341,15 @@ void ParallelPass::run() {
|
|||
new_func_def->insert(0, new_tid_def.children);
|
||||
new_func_def->swap(*func_def, true);
|
||||
new_block.swap(*func_call, true);
|
||||
auto code = func_def->to_string();
|
||||
bool has_atomic = code.find("atomic") != string::npos;
|
||||
if (!fix_thread_num) {
|
||||
if (has_atomic) {
|
||||
func_call->find_define("thread_num")->attrs["rvalue"] = "min(1<<max((NanoVector::get_nbits(" + nums + "/16)-2),0)," + S(thread_num) + ")";
|
||||
} else {
|
||||
func_call->find_define("thread_num")->attrs["rvalue"] = "min(1<<max((NanoVector::get_nbits(" + nums + ")-2),0)," + S(thread_num) + ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
ir->remove_all_unused();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue