Merge branch 'master' of github.com:Jittor/jittor

This commit is contained in:
Dun Liang 2021-09-15 17:45:37 +08:00
commit 500dfc6ee3
28 changed files with 1365 additions and 83 deletions

View File

@ -0,0 +1,76 @@
Jittor Debugging Tips
=====================
This document covers debugging methods and tips for several kinds of abnormal situations.
## NaN and Inf blowups
During model training, NaN or Inf values may appear because of numerical instability. To help you locate the code that produces them, you can set the following environment variables:
```bash
export JT_CHECK_NAN=1
export trace_py_var=3
```
Here, the environment variable `JT_CHECK_NAN=1` makes Jittor raise an error and stop the program automatically when an operator outputs an abnormal floating-point value, and `trace_py_var=3` makes Jittor report the Python source line corresponding to each operator (3 is the highest verbosity level).
Note that enabling these two features slows Jittor down significantly and triggers recompilation. Please do not enable this mode in training or production environments, and do not leave it on for long periods.
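As a rough illustration (a minimal sketch, not part of the original document), the following script produces Inf and NaN values; when it is launched with the environment variables above, the check should stop at the offending operator and report the corresponding Python line:
```python
import jittor as jt

x = jt.array([1.0, 0.0, 2.0])
y = jt.array([1.0, 1.0, 1.0]) / x   # the second element becomes Inf
z = y - y                           # Inf - Inf produces NaN
print(z.numpy())
```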
## Inaccurate error locations
By default, Jittor uses lazy execution for acceleration, so operators are not executed at the time they are created. This can make error messages point to the wrong location. You can disable lazy execution and switch to eager execution with the following environment variable:
```bash
export lazy_execution=0
```
Or turn it off from Python code via the flag:
```python
jt.flags.lazy_execution=0
```
## Out of memory
When Jittor cannot run because of memory-related problems, it reports the memory usage to you. Running out of memory usually falls into one of two cases:
1. The model is too large, and the program crashes within a single iteration.
2. Memory usage keeps growing over many iterations until it is eventually exhausted.
**For the first case**, you may need to reduce the model or data size, or use [multi-GPU training](jittor.mpi). In addition, you can force Jittor to reclaim memory inside each iteration:
```python
for ...:
...
jt.sync_all()
jt.gc()
```
If you are using CUDA together with convolutions, it is also possible that the temporary workspace consumed by convolution is too large. In that case you can disable cuDNN's temporary memory allocation by inserting the following code at the very beginning of your program:
```python
jt.cudnn.set_max_workspace_ratio(0.0)
```
**For the second case**, there may be a memory leak. Please check whether you have global variables that are never released, or global variables whose gradient is not stopped, which makes the computation graph keep growing. To diagnose it, you can insert the following debugging code inside each iteration:
```python
for ...:
...
jt.sync_all()
jt.display_memory_info()
```
Jittor will print the memory consumption, the size of the computation graph (`lived_var`, `lived_op`), and the number of variables held by the user (`hold_var`). If the computation graph keeps growing, please review your code, or contact us by opening a GitHub issue with the error log and a script that reproduces the problem.
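A typical source of this kind of growth is accumulating a Jittor variable into a Python-level variable without stopping its gradient, so the graph of every iteration stays reachable. The sketch below is illustrative only (`model` and `running_loss` are hypothetical names); detaching with `stop_grad()`, or converting to a Python number with `item()`, breaks the dependency:
```python
running_loss = 0.0
for batch in dataset:
    loss = model(batch)                 # hypothetical forward pass returning a jt.Var
    # BAD: keeps the graph of every iteration reachable from running_loss
    # running_loss = running_loss + loss
    # BETTER: detach before accumulating, so old graphs can be freed
    running_loss = running_loss + loss.stop_grad()
    # or simply: running_loss += loss.item()
```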
## Segmentation faults
If Jittor crashes with a segmentation fault, we recommend opening a GitHub issue and contacting us with the error log and a script that reproduces the problem. You can also diagnose the program and the framework with the following environment variables:
```bash
export debug=1
export gdb_attach=1
```
Here, the environment variable `debug=1` enables Jittor's debug mode, which greatly reduces performance but preserves debugging information, and `gdb_attach=1` automatically attaches gdb to the main Jittor process so that you can step through it. For how to use gdb, see the [GDB Cheat Sheet](https://darkdust.net/files/GDB%20Cheat%20Sheet.pdf).

View File

@ -45,7 +45,8 @@ language = 'zh_CN'
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'recommonmark',
# 'recommonmark',
'myst_parser',
'sphinx.ext.autodoc',
# Auto-generate section labels.
'sphinx.ext.autosectionlabel',

View File

@ -32,10 +32,22 @@
jittor.loss3d
.. toctree::
:maxdepth: 2
:caption: Jittor model zoo:
JDet
segmentation-jittor
InstanceSegmentation-jittor
gan-jittor
PointCloudLib
jrender
.. toctree::
:maxdepth: 1
:caption: Others:
Jittor调试技巧
Tutorials <https://cg.cs.tsinghua.edu.cn/jittor/tutorial/>
Indices and tables

View File

@ -1,7 +1,37 @@
jittor.mpi
=====================
This is the API documentation for Jittor's MPI module. You can access it with `from jittor import mpi`.
Jittor's distributed training is based on MPI (Message Passing Interface). This document is mainly a tutorial on multi-GPU and distributed training with Jittor MPI.
## Installing Jittor MPI
Jittor depends on `OpenMPI`. You can install `OpenMPI` with the following command:
```bash
sudo apt install openmpi-bin openmpi-common libopenmpi-dev
```
Jittor automatically checks whether `mpicc` is present in the environment. If `mpicc` is detected successfully, the following message is printed:
```
[i 0502 14:09:55.758481 24 __init__.py:203] Found mpicc(1.10.2) at /usr/bin/mpicc
```
If Jittor does not find MPI in the environment, you can also specify the path of mpicc manually by setting an environment variable: `export mpicc_path=/you/mpicc/path`
Once `OpenMPI` is installed, you do not need to modify your code. All you need to do is change the launch command, and Jittor will automatically run it in parallel with data parallelism.
```bash
# single-GPU training
python3.7 -m jittor.test.test_resnet
# distributed multi-GPU training
mpirun -np 4 python3.7 -m jittor.test.test_resnet
# multi-GPU training on specific GPUs
CUDA_VISIBLE_DEVICES="2,3" mpirun -np 2 python3.7 -m jittor.test.test_resnet
```
This convenience is backed by Jittor's distributed operators; the MPI operators supported by Jittor use NCCL in the backend for further acceleration. All of Jittor's distributed algorithms are developed in the Python frontend, which makes distributed algorithms more flexible and much easier to develop.
## Adapting single-GPU code to multiple GPUs
@ -11,6 +41,8 @@ jittor.mpi
* jittor.nn.BatchNorm*: synchronized batch norm
* jittor.dataset: automatic data parallelism
When you train with MPI, Jittor's Dataset class automatically distributes the data across processes in parallel. Note that the batch size set on the Dataset class is the **sum of the batch sizes of all nodes**, i.e. the total batch size, not the batch size received by a single node.
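For example (a minimal sketch; `YourDataset` and the numbers below are illustrative), with a total batch size of 64 and `mpirun -np 4`, each process receives batches of roughly 16 samples:
```python
from jittor.dataset import Dataset

class YourDataset(Dataset):
    def __init__(self):
        super().__init__()
        # 64 is the TOTAL batch size summed over all MPI processes
        self.set_attrs(total_len=1000, batch_size=64)

    def __getitem__(self, k):
        return k

dataset = YourDataset()
# launched with: mpirun -np 4 python3.7 train.py
# each of the 4 processes then sees batches of about 64 / 4 = 16 samples
for batch in dataset:
    ...
```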
In most cases, single-GPU training code can run distributed on multiple GPUs simply with `mpirun`. However, the code still needs adjustment in the following situations:
1. Writing to disk (saving models, saving curves)
@ -93,10 +125,30 @@ def val(epoch):
......
```
Below is the MPI API reference of jittor.
## MPI interfaces
Below is the MPI API reference of jittor.
The MPI interfaces currently exposed are:
* `jt.in_mpi`: when Jittor is not running inside MPI, `jt.mpi == False`; you can use this to check whether you are in an MPI environment.
* `jt.world_rank`: the total number of processes; 1 if MPI is not used.
* `jt.rank`: the index of the current process, in the range `0` to `jt.world_rank-1`; 0 if MPI is not used.
* `jt.mpi`: Jittor's MPI module.
* `jt.Module.mpi_param_broadcast(root=0)`: broadcast the parameters of the module from the root node to all other nodes.
* `jt.mpi.mpi_reduce(x, op='add', root=0)`: reduce the variable x of all nodes to the root node with operator op. If op is 'add' or 'sum', all variables are summed; if op is 'mean', the mean is taken.
<img src="https://cg.cs.tsinghua.edu.cn/jittor/images/tutorial/2020-5-2-16-44-distributed/mpi_reduce.png">
* `jt.mpi.mpi_broadcast(x, root=0)`: broadcast the variable x from the root node to all nodes.
<img src="https://cg.cs.tsinghua.edu.cn/jittor/images/tutorial/2020-5-2-16-44-distributed/mpi_broadcast.png">
* `jt.mpi.mpi_all_reduce(x, op='add')`: reduce the variable x across all nodes and broadcast the result back to all nodes. If op is 'add' or 'sum', all variables are summed; if op is 'mean', the mean is taken.
<img src="https://cg.cs.tsinghua.edu.cn/jittor/images/tutorial/2020-5-2-16-44-distributed/mpi_all_reduce.png">
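A minimal usage sketch of these interfaces (assuming the script is launched with `mpirun -np 2` or more; the variable names are illustrative):
```python
import jittor as jt

x = jt.array([float(jt.rank)])       # each process holds a different value

if jt.in_mpi:
    # sum x over all processes; every process receives the result
    s = jt.mpi.mpi_all_reduce(x, op='add')
    # reduce x to the root process only
    r = jt.mpi.mpi_reduce(x, op='add', root=0)
    # copy the value of x on rank 0 to every process
    b = jt.mpi.mpi_broadcast(x, root=0)
    if jt.rank == 0:
        print(s.numpy(), r.numpy(), b.numpy())
```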
```eval_rst
.. automodule:: jittor_mpi_core
@ -106,3 +158,56 @@ def val(epoch):
:members:
:undoc-members:
```
## Example: a distributed synchronized batch norm layer with MPI
The code below is an example of implementing a distributed synchronized batch norm layer with Jittor. On top of the ordinary batch norm layer, only three extra lines of code are needed to make batch norm distributed. The added code is:
```python
# synchronize the mean and variance to all nodes via all reduce
if self.sync and jt.mpi:
xmean = xmean.mpi_all_reduce("mean")
x2mean = x2mean.mpi_all_reduce("mean")
```
> Note: Jittor already implements a synchronized batch norm layer internally; you do not need to implement it yourself.
The complete code of the distributed synchronized batch norm layer:
```python
class BatchNorm(Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
assert affine == None
self.sync = sync
self.num_features = num_features
self.is_train = is_train
self.eps = eps
self.momentum = momentum
self.weight = init.constant((num_features,), "float32", 1.0)
self.bias = init.constant((num_features,), "float32", 0.0)
self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
def execute(self, x):
if self.is_train:
xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
# synchronize the mean and variance to all nodes via all reduce
if self.sync and jt.mpi:
xmean = xmean.mpi_all_reduce("mean")
x2mean = x2mean.mpi_all_reduce("mean")
xvar = x2mean-xmean*xmean
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum
self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum
else:
running_mean = self.running_mean.broadcast(x, [0,2,3])
running_var = self.running_var.broadcast(x, [0,2,3])
norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
w = self.weight.broadcast(x, [0,2,3])
b = self.bias.broadcast(x, [0,2,3])
return norm_x * w + b
```

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.92'
__version__ = '1.2.3.102'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@ -69,6 +69,15 @@ def safeunpickle(path):
except:
raise RuntimeError("pytorch need to be installed when load pth format.")
model_dict = torch.load(path, map_location=torch.device('cpu'))
try:
for k, v in model_dict.items():
try:
if not isinstance(v, np.ndarray) and hasattr(v, "cpu"):
model_dict[k] = v.cpu().detach().numpy()
except:
pass
except:
pass
return model_dict
with open(path, "rb") as f:
s = f.read()
@ -332,12 +341,12 @@ def std(x):
return out
Var.std = std
def norm(x, k=2, dim=-1, keepdim=False):
assert k==1 or k==2
if k==1:
def norm(x, p=2, dim=-1, keepdim=False, eps=1e-30):
assert p==1 or p==2
if p==1:
return x.abs().sum(dim, keepdim)
if k==2:
return (x.sqr()).sum(dim, keepdim).maximum(1e-6).sqrt()
if p==2:
return (x.sqr()).sum(dim, keepdim).maximum(eps).sqrt()
Var.norm = norm
origin_reshape = reshape
@ -796,12 +805,30 @@ class Module:
return _uniq(ps)
def named_parameters(self):
ps = self.parameters()
return [ (p.name(), p) for p in ps ]
uniq_set = set()
ps = {}
stack = []
def callback(parents, k, v, n):
stack.append(str(k))
dc = v.__dict__
if isinstance(v, nn.ParameterList):
dc = v.params
for k2, p in dc.items():
if isinstance(k2, str) and k2.startswith("_"): continue
if isinstance(p, Var):
if id(p) in uniq_set: continue
uniq_set.add(id(p))
pname = ".".join(stack[1:]+[str(k2)])
ps[pname] = p
if len(pname) > len(p.name()):
p.name(pname)
def callback_leave(parents, k, v, n):
stack.pop()
self.dfs([], None, callback, callback_leave)
return ps
def state_dict(self):
ps = self.parameters()
return { p.name(): p for p in ps }
return self.named_parameters()
def load_state_dict(self, params):
self.load_parameters(params)
@ -1011,10 +1038,13 @@ Arguments of hook are defined as::
>>> net.save('net.pkl')
>>> net.load('net.pkl')
'''
params = self.parameters()
params = self.named_parameters()
params_dict = {}
for p in params:
params_dict[p.name()] = p.data
for k, v in params.items():
if isinstance(v, Var):
params_dict[k] = v.numpy()
else:
params_dict[k] = v
safepickle(params_dict, path)
def load(self, path: str):

View File

@ -127,6 +127,7 @@ class Dataset(object):
self.epoch_id = 0
self.sampler = None
self._disable_workers = False
self._shuffle_rng = np.random.default_rng(1)
def __getitem__(self, index):
raise NotImplementedError
@ -214,8 +215,8 @@ class Dataset(object):
jittor_utils.cc.init_subprocess()
jt.jt_init_subprocess()
seed = jt.get_seed()
wseed = (seed ^ worker_id) ^ 1234
jt.set_seed(wseed)
wseed = (seed ^ (worker_id*1167)) ^ 1234
jt.set_global_seed(wseed)
# parallel_op_compiler still problematic,
# it is not work on ubuntu 16.04. but worked on ubuntu 20.04
# it seems like the static value of parallel compiler
@ -426,7 +427,10 @@ Example::
elif self.shuffle == False:
index_list = get_order_list(self.total_len)
else:
index_list = get_random_list(self.total_len)
# use _shuffle_rng to generate a shuffle list
# that is consistent across multiprocess workers
# index_list = get_random_list(self.total_len)
index_list = self._shuffle_rng.permutation(range(self.total_len))
# scatter index_list for all mpi process
# scatter rule:

View File

@ -7,6 +7,8 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import os
import string
import numpy as np
import gzip
from PIL import Image
@ -94,3 +96,105 @@ class MNIST(Dataset):
for url, md5 in resources:
filename = url.rpartition('/')[2]
download_url_to_local(url, filename, self.data_root, md5)
class EMNIST(Dataset):
'''
Jittor's own class for loading EMNIST dataset.
Args::
[in] data_root(str): your data root.
[in] split(str): one of 'byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist'.
[in] train(bool): use the training split if True, otherwise the test split.
[in] download(bool): download data automatically if download is True.
[in] batch_size(int): Data batch size.
[in] shuffle(bool): Shuffle data if true.
[in] transform(jittor.transform): transform data.
Example::
from jittor.dataset.mnist import EMNIST
train_loader = EMNIST(train=True).set_attrs(batch_size=16, shuffle=True)
for i, (imgs, target) in enumerate(train_loader):
...
'''
_merged_classes = {'c', 'i', 'j', 'k', 'l', 'm', 'o', 'p', 's', 'u', 'v', 'w', 'x', 'y', 'z'}
_all_classes = set(string.digits + string.ascii_letters)
classes_split_dict = {
'byclass': sorted(list(_all_classes)),
'bymerge': sorted(list(_all_classes - _merged_classes)),
'balanced': sorted(list(_all_classes - _merged_classes)),
'letters': ['N/A'] + list(string.ascii_lowercase),
'digits': list(string.digits),
'mnist': list(string.digits),
}
def __init__(self, data_root=dataset_root+"/emnist_data/",
split='byclass',
train=True,
download=True,
batch_size = 16,
shuffle = False,
transform=None):
# if you want to test resnet etc., set input_channel = 3, because those nets expect 3 input channels
super().__init__()
self.data_root = data_root
self.is_train = train
self.transform = transform
self.batch_size = batch_size
self.shuffle = shuffle
if download == True:
self.download_url()
data_root = os.path.join(data_root, "gzip")
filesname = [
f"emnist-{split}-train-images-idx3-ubyte.gz",
f"emnist-{split}-t10k-images-idx3-ubyte.gz",
f"emnist-{split}-train-labels-idx1-ubyte.gz",
f"emnist-{split}-t10k-labels-idx1-ubyte.gz"
]
for i in range(4):
filesname[i] = os.path.join(data_root, filesname[i])
self.mnist = {}
if self.is_train:
with gzip.open(filesname[0], 'rb') as f:
self.mnist["images"] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28, 28).transpose(0,2,1)
with gzip.open(filesname[2], 'rb') as f:
self.mnist["labels"] = np.frombuffer(f.read(), np.uint8, offset=8)
else:
with gzip.open(filesname[1], 'rb') as f:
self.mnist["images"] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28, 28).transpose(0,2,1)
with gzip.open(filesname[3], 'rb') as f:
self.mnist["labels"] = np.frombuffer(f.read(), np.uint8, offset=8)
assert(self.mnist["images"].shape[0] == self.mnist["labels"].shape[0])
self.total_len = self.mnist["images"].shape[0]
# this function must be called
self.set_attrs(total_len = self.total_len)
def __getitem__(self, index):
img = Image.fromarray(self.mnist['images'][index]).convert('RGB')
if self.transform:
img = self.transform(img)
return trans.to_tensor(img), self.mnist['labels'][index]
def download_url(self):
'''
Download the EMNIST dataset; this function is called when download is True.
'''
resources = [
("https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip", "58c8d27c78d21e728a6bc7b3cc06412e"),
]
for url, md5 in resources:
filename = "emnist.zip"
download_url_to_local(url, filename, self.data_root, md5)
import zipfile
zf = zipfile.ZipFile(os.path.join(self.data_root, filename))
try:
zf.extractall(path=self.data_root)
except RuntimeError as e:
print(e)
raise
zf.close()

View File

@ -0,0 +1,125 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>
// Dun Liang <randonlang@gmail.com>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <algorithm>
#include "var.h"
#include "cub_cumsum_op.h"
#include <vector>
#include "executor.h"
#include "ops/op_register.h"
#ifdef JIT_cuda
#include <cub/cub.cuh>
#include <cub/block/block_scan.cuh>
#include <thrust/iterator/reverse_iterator.h>
#endif
namespace jittor {
#ifndef JIT
static auto make_cub_cumsum = get_op_info("cub_cumsum")
.get_constructor<VarPtr, Var*, bool>();
CubCumsumOp::CubCumsumOp(Var* x, bool reverse) : x(x),reverse(reverse) {
flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_cuda, 1);
y = create_output(nullptr, x->dtype());
}
void CubCumsumOp::infer_shape() {
ASSERT(x->shape.size() == 1 || x->shape.size() == 2); //TODO:support batch_cumsum
y->set_shape(x->shape);
}
void CubCumsumOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype();
jk << _CS("][Ty:") << y->dtype();
jk << _CS("][reverse:") << reverse;
jk << _CS("]");
}
VarPtr CubCumsumOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return make_cub_cumsum(dout, !reverse);
// return ArgsortOp::get_grad(out, dout, v, v_index, v->shape.size()-1, y);
}
#else // JIT
#ifdef JIT_cuda
#define ITEMS_PER_THREAD 4
#define BLOCK_THREADS 1024
__global__ void BlockScanKernel(Tx* __restrict__ xp, Ty* __restrict__ yp, int batch_num, int num_items) {
typedef cub::BlockScan<Tx, BLOCK_THREADS> BlockScanT;
__shared__ typename BlockScanT::TempStorage temp_storage;
int batch_id = blockIdx.x;
int offset = threadIdx.x * ITEMS_PER_THREAD;
for (int block_offset = offset; block_offset < num_items; block_offset += BLOCK_THREADS * ITEMS_PER_THREAD) {
int items = ITEMS_PER_THREAD;
if (block_offset + ITEMS_PER_THREAD > num_items) {
items = num_items - block_offset;
}
Tx thread_data[ITEMS_PER_THREAD];
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
if (i<items)
#if reverse
thread_data[i] = xp[batch_id * num_items + (num_items - 1 - (block_offset + i))];
#else
thread_data[i] = xp[batch_id * num_items + block_offset + i];
#endif
}
BlockScanT(temp_storage).InclusiveSum(thread_data, thread_data);
__syncthreads();
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
if (i<items)
#if reverse
yp[batch_id * num_items + (num_items - 1 - (block_offset + i))] = thread_data[i];
#else
yp[batch_id * num_items + block_offset + i] = thread_data[i];
#endif
}
}
}
void CubCumsumOp::jit_run() {
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Ty>();
if (x->shape.size() == 1){
int num_items = x->shape[0];
// Determine temporary device storage requirements for inclusive prefix sum
void *d_temp_storage = NULL;
size_t temp_storage_bytes = 0, temp_storage_allocation;
cub::DeviceScan::InclusiveSum(NULL, temp_storage_bytes, xp, yp, num_items);
d_temp_storage = exe.temp_allocator->alloc(temp_storage_bytes, temp_storage_allocation);
// Allocate temporary storage for inclusive prefix sum
// cudaMalloc(&d_temp_storage, temp_storage_bytes);
// Run inclusive prefix sum
if (reverse) {
auto xp_ = thrust::make_reverse_iterator(xp + num_items);
auto yp_ = thrust::make_reverse_iterator(yp + num_items);
cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, xp_, yp_, num_items);
} else {
cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, xp, yp, num_items);
}
// yp <-- [8, 14, 21, 26, 29, 29, 38]
exe.temp_allocator->free(d_temp_storage, temp_storage_bytes, temp_storage_allocation);
} else {
int batch_num = x->shape[0];
int num_items = x->shape[1];
BlockScanKernel<<<batch_num, BLOCK_THREADS>>>(xp, yp, batch_num, num_items);
}
}
#endif // JIT_cuda
#endif // JIT
} // jittor

View File

@ -0,0 +1,28 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>
// Dun Liang <randonlang@gmail.com>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
namespace jittor {
struct CubCumsumOp : Op {
Var* x, * y;
bool reverse;
CubCumsumOp(Var* x, bool reverse=false);
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
const char* name() const override { return "cub_cumsum"; }
DECLARE_jit_run;
};
} // jittor

View File

@ -20,6 +20,8 @@ namespace jittor {
CublasMatmulOp::CublasMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b)
: a(a), b(b), trans_a(trans_a), trans_b(trans_b) {
flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0);
// TODO: support int8 * int8
ASSERT(a->dtype().is_float() && b->dtype().is_float()) << "type of two inputs should be the same";
// TODO: support different input type
@ -51,7 +53,6 @@ void CublasMatmulOp::jit_prepare(JK& jk) {
}
#else // JIT
#ifdef JIT_cpu
#pragma clang diagnostic ignored "-Wtautological-compare"
void CublasMatmulOp::jit_run() {
cublasHandle_t& handle_ = cublas_handle;
@ -78,7 +79,6 @@ void CublasMatmulOp::jit_run() {
a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
c->ptr<T>(), k));
}
#endif
#endif // JIT
} // jittor

View File

@ -19,11 +19,12 @@ struct cublas_initer {
inline cublas_initer() {
if (!get_device_count()) return;
checkCudaErrors(cublasCreate(&cublas_handle));
LOGv << "cublasCreate finished";
LOGv << "cublasCreate finished" << (void*)cublas_handle;
}
inline ~cublas_initer() {
if (!get_device_count()) return;
LOGv << "cublasDestroy:" << (void*)cublas_handle;
checkCudaErrors(cublasDestroy(cublas_handle));
LOGv << "cublasDestroy finished";
}

View File

@ -100,15 +100,15 @@ def repeat(x, *shape):
x_shape = (len_shape - len_x_shape) * [1] + x.shape
x = x.broadcast(x_shape)
elif len_x_shape > len_shape:
rep_shape = (len_x_shape - len_shape) * [1] + shape
#TODO: if input.shape[i] == 1, do not add [1]
rep_shape = (len_x_shape - len_shape) * [1] + list(shape)
reshape_shape = []
broadcast_shape = []
for x_s,r_s in zip(x_shape,rep_shape):
reshape_shape.append(1)
if r_s != 1:
reshape_shape.append(1)
broadcast_shape.append(r_s)
reshape_shape.append(x_s)
broadcast_shape.append(r_s)
broadcast_shape.append(1)
x = x.reshape(reshape_shape)
@ -344,7 +344,7 @@ def cross(input, other, dim=-1):
return jt.contrib.concat([a1.unsqueeze(dim),a2.unsqueeze(dim),a3.unsqueeze(dim)], dim=dim)
jt.Var.cross = cross
def normalize(input, p=2, dim=1, eps=1e-12):
def normalize(input, p=2, dim=1, eps=1e-30):
r'''
Performs L_p normalization of inputs over specified dimension.
@ -376,9 +376,7 @@ def normalize(input, p=2, dim=1, eps=1e-12):
[0.02647221 0.59484214 0.80340654]
[0.6910677 0.58067477 0.4303977 ]]
'''
assert p == 2
if p == 2:
return input / jt.maximum(input.sqr().sum(dim,True).sqrt(), eps)
return input / input.norm(p, dim, True, eps)
jt.Var.normalize = normalize
def unbind(x, dim=0):
@ -661,32 +659,66 @@ def _prod(x,dim=0):
return jt.exp(x)
def cumsum_forward(np, data):
a = data['inputs'][0]
b = data['outputs'][0]
np.cumsum(a, axis=1, out=b)
def numpy_cumsum(x, dim=None):
def cumsum_forward(np, data):
dim = data['inputs'][1].item()
a = data['inputs'][0]
b = data['outputs'][0]
np.cumsum(a, axis=dim, out=b)
def cumsum_backward(np, data):
dout = data['dout']
out = data['outputs'][0]
np.cumsum(dout[:, ::-1], axis=1, out=out)
np.copyto(out, out[:, ::-1])
def cumsum_backward(np, data):
dim = data['inputs'][1].item()
dout = data['dout']
out = data['outputs'][0]
np.cumsum(dout[..., ::-1], axis=dim, out=out)
np.copyto(out, out[..., ::-1])
if (dim == None):
dim = -1
assert(dim >= -1 and dim < len(x.shape))
dim_var = jt.array([dim],dtype=int)
return jt.numpy_code(x.shape, x.dtype, [x, dim_var.detach()], cumsum_forward, [cumsum_backward])
def cub_cumsum(x, dim=None):
if (dim == None):
dim = -1
assert(dim >= -1 and dim < len(x.shape))
shape = x.shape
if (dim != -1 and dim != len(shape) - 1):
order = range(len(shape))
order[dim], order[-1] = order[-1], order[dim]
shape[dim], shape[-1] = shape[-1], shape[dim]
x = x.permute(order)
if (len(shape) > 2):
x = x.reshape([-1, shape[-1]])
x = jt.compile_extern.cub_ops.cub_cumsum(x)
if (len(shape) > 2):
x = x.reshape(shape)
if (dim != -1 and dim != len(shape) - 1):
x = x.permute(order)
return x
def cumsum(x, dim=None):
'''
Parameters:
-----------
x: [batch_size, N], jt.var
x: jt.var
dim: int
Returns:
--------
the cumulative sum of x
the cumulative sum in dim of x
'''
return jt.numpy_code(x.shape, x.dtype, [x], cumsum_forward, [cumsum_backward])
if (dim == None):
dim = -1
assert(dim >= -1 and dim < len(x.shape))
if jt.has_cuda:
return cub_cumsum(x, dim)
else:
return numpy_cumsum(x, dim)
jt.Var.cumsum = cumsum
def cumprod(x,dim=0):
def cumprod(x,dim=None):
x = jt.log(x)
x = cumsum(x,dim=dim)
return jt.exp(x)
@ -749,26 +781,26 @@ def triu_(x,diagonal=0):
jt.Var.triu_ = triu_
def print_tree(now, max_memory_size, prefix1, prefix2, build_by):
def format_size(s):
def format_size(s, end='B'):
if (s < 1024):
s = str(s)
return s + ' B'
return s + ' '+end
if (s < 1024*1024):
s = format(s/1024, '.2f')
return s + ' KB'
return s + ' K'+end
if (s < 1024*1024*1024):
s = format(s/1024/1024, '.2f')
return s + ' MB'
return s + ' M'+end
s = format(s/1024/1024/1024, '.2f')
return s + ' GB'
return s + ' G'+end
out = ''
tab = ' '
out += prefix1+now['name']+'('+now['type']+')\n'
out += prefix2+'['+format_size(now['size'])+'; '+format(now['size']/max_memory_size*100, '.2f')+'%]\n'
out += prefix2+'['+format_size(now['size'])+'; '+format(now['size']/max_memory_size*100, '.2f')+'%; cnt:'+format_size(now['cnt'],'') + ']\n'
if (build_by == 0):
for p in now['path']:
out += prefix2+p+'\n'
@ -834,7 +866,7 @@ Output::
vars_ = vars_[1:]
for v_ in vars_:
v__ = v_.split(div2)
var = {'size':int(v__[1]), 'stack':[]}
var = {'size':int(v__[1]), 'stack':[], 'cnt':1}
v__ = v__[2:-1]
for s_ in v__:
s__ = s_.split(div3)
@ -842,7 +874,7 @@ Output::
var['stack'].append(s)
vars.append(var)
if (build_by == 0): # build tree by name
tree = {'name':'root', "children":[], 'size':0, 'path':[], 'type':''}
tree = {'name':'root', "children":[], 'size':0, 'cnt':1, 'path':[], 'type':''}
def find_child(now, key):
for c in now['children']:
@ -852,6 +884,7 @@ Output::
for v in vars:
now = tree
now['size'] += v['size']
now['cnt'] += v['cnt']
for s in v['stack']:
ch = find_child(now, s['name'])
if (ch is not None):
@ -860,12 +893,13 @@ Output::
assert(ch['type']==s['type'])
now = ch
now['size'] += v['size']
now['cnt'] += v['cnt']
else:
now_ = {'name':s['name'], "children":[], 'size':v['size'], 'path':[s['path']], 'type':s['type']}
now_ = {'name':s['name'], "children":[], 'size':v['size'], 'cnt':v['cnt'], 'path':[s['path']], 'type':s['type']}
now['children'].append(now_)
now = now_
elif (build_by == 1): # build tree by path
tree = {'name':'root', "children":[], 'size':0, 'path':'_root_', 'type':''}
tree = {'name':'root', "children":[], 'size':0, 'cnt':0, 'path':'_root_', 'type':''}
def find_child(now, key):
for c in now['children']:
@ -875,13 +909,15 @@ Output::
for v in vars:
now = tree
now['size'] += v['size']
now['cnt'] += v['cnt']
for s in v['stack']:
ch = find_child(now, s['path'])
if (ch is not None):
now = ch
now['size'] += v['size']
now['cnt'] += v['cnt']
else:
now_ = {'name':s['name'], "children":[], 'size':v['size'], 'path':s['path'], 'type':s['type']}
now_ = {'name':s['name'], "children":[], 'size':v['size'], 'cnt':v['cnt'], 'path':s['path'], 'type':s['type']}
now['children'].append(now_)
now = now_
else:
@ -1063,14 +1099,17 @@ def randperm(n, dtype="int32"):
index, _ = jt.argsort(key)
return index.cast(dtype)
def set_global_seed(seed):
def set_global_seed(seed, different_seed_for_mpi=True):
''' Sets the seeds of the random number generators of Python, numpy and jittor,
simultaneously.
.. note::
Jittor also guarantees that each worker of jittor.dataset.Dataset holds a different seed,
which is global_seed ^ worker_id ^ 1234.
It also guarantees that each process holds a different seed when using MPI,
which is (global_seed ^ (worker_id*1167)) ^ 1234 + jt.rank * 2591
'''
if (different_seed_for_mpi):
seed = seed + jt.rank * 2591
import random
random.seed(seed)
jt.set_seed(seed)
@ -1081,6 +1120,9 @@ def set_global_seed(seed):
except:
pass
import time
set_global_seed(int(time.time() * 1000000) % 100000007)
def searchsorted(sorted, values, right=False):
"""
Find the indices from the innermost dimension of `sorted` for each `values`.
@ -1269,3 +1311,348 @@ jt.Var.roll = roll
def safe_log(x):
return jt.safe_clip(x, 1e-30, 1e30).log()
jt.Var.safe_log = safe_log
class _CTCLossFunction(jt.Function):
def execute(self, log_probs, targets, input_lengths, target_lengths, blank=0, zero_infinity=False):
self.blank = blank
T, N, C = log_probs.shape
_N, S = targets.shape
assert _N == N
log_alpha = jt.full([T,N,S*2+1], -1e30)
result = jt.empty((N,))
jt.code([log_probs, targets, input_lengths, target_lengths], [log_alpha, result], cpu_src=f"""
constexpr int blank = {blank};
for (int i=0; i<in0_shape1; i++) {{
int input_len = @in2(i);
int target_len = @in3(i);
@out0(0,i,0) = @in0(0,i,blank);
if (target_len)
@out0(0,i,1) = @in0(0,i,@in1(i,0));
for (int j=1; j<input_len; j++)
for (int k=0; k<target_len*2+1; k++) {{
int target = k%2 ? @in1(i,k/2) : blank;
int target_2 = target;
if (k>1 && k%2) target_2 = @in1(i,k/2-1);
out_type l1 = @out0(j-1,i,k);
out_type l2 = -1e30;
if (k>0) l2 = @out0(j-1,i,k-1);
out_type l3 = -1e30;
if (k>1 && target_2 != target)
l3 = @out0(j-1,i,k-2);
out_type m = std::max(l1, std::max(l2, l3));
@out0(j,i,k) = std::log(
std::exp(l1-m) +
std::exp(l2-m) +
std::exp(l3-m)
) + m + @in0(j,i,target);
}}
if (input_len==0)
@out1(i) = @out0(0,i,0);
else {{
out_type l1 = @out0(input_len-1, i, target_len*2);
out_type l2 = -1e30;
if (target_len)
l2 = @out0(input_len-1, i, target_len*2-1);
out_type m = std::max(l1, l2);
out_type log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
@out1(i) = -log_likelihood;
}}
}}
""", cuda_src=f"""
__global__ void kernel(@ARGS_DEF) {{
@PRECALC;
constexpr int blank = {blank};
for (int i=blockIdx.x; i<in0_shape1; i+=gridDim.x) {{
int input_len = @in2(i);
int target_len = @in3(i);
@out0(0,i,0) = @in0(0,i,blank);
if (target_len)
@out0(0,i,1) = @in0(0,i,@in1(i,0));
for (int j=1; j<input_len; j++)
for (int k=threadIdx.x; k-threadIdx.x<target_len*2+1; k+=blockDim.x) {{
__syncthreads();
if (k>=target_len*2+1)
continue;
int target = k%2 ? @in1(i,k/2) : blank;
int target_2 = target;
if (k>1 && k%2) target_2 = @in1(i,k/2-1);
out_type l1 = @out0(j-1,i,k);
out_type l2 = -1e30;
if (k>0) l2 = @out0(j-1,i,k-1);
out_type l3 = -1e30;
if (k>1 && target_2 != target)
l3 = @out0(j-1,i,k-2);
out_type m = ::max(l1, ::max(l2, l3));
@out0(j,i,k) = ::log(
::exp(l1-m) +
::exp(l2-m) +
::exp(l3-m)
) + m + @in0(j,i,target);
}}
__syncthreads();
if (input_len==0)
@out1(i) = @out0(0,i,0);
else {{
out_type l1 = @out0(input_len-1, i, target_len*2);
out_type l2 = -1e30;
if (target_len)
l2 = @out0(input_len-1, i, target_len*2-1);
out_type m = ::max(l1, l2);
out_type log_likelihood = ::log(::exp(l1-m)+::exp(l2-m))+m;
@out1(i) = -log_likelihood;
}}
}}
}}
kernel<<<std::min(in0_shape1, 1024), std::min(in1_shape1*2+1, 1024)>>>(@ARGS);
""")
self.saved_var = [log_probs, targets, input_lengths, target_lengths, log_alpha, result]
return result
def grad(self, dout):
blank = self.blank
inputs = self.saved_var + [dout]
dlog_probs = jt.zeros_like(inputs[0])
dlog_alpha = jt.zeros_like(inputs[4])
jt.code(inputs, [dlog_probs, dlog_alpha], cpu_src=f"""
constexpr int blank = {blank};
for (int i=0; i<in0_shape1; i++) {{
int input_len = @in2(i);
int target_len = @in3(i);
if (input_len==0)
// write out1 --> read in6
// out1(i) = out0(0,i,0);
@out1(0,i,0) = @in6(i);
else {{
out_type l1 = @in4(input_len-1, i, target_len*2);
out_type l2 = -1e30;
if (target_len)
l2 = @in4(input_len-1, i, target_len*2-1);
out_type m = std::max(l1, l2);
// out_type log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
// out1(i) = -log_likelihood;
out_type l1_exp = std::exp(l1-m);
out_type l2_exp = std::exp(l2-m);
out_type sumexp = l1_exp + l2_exp;
out_type dlog_likelihood = -@in6(i);
out_type dl1 = dlog_likelihood * l1_exp / sumexp;
out_type dl2 = dlog_likelihood * l2_exp / sumexp;
@out1(input_len-1, i, target_len*2) = dl1;
if (target_len)
@out1(input_len-1, i, target_len*2-1) = dl2;
}}
for (int j=input_len-1; j>0; j--)
for (int k=0; k<target_len*2+1; k++) {{
int target = k%2 ? @in1(i,k/2) : blank;
int target_2 = target;
if (k>1 && k%2) target_2 = @in1(i,k/2-1);
out_type l1 = @in4(j-1,i,k);
out_type l2 = -1e30;
if (k>0) l2 = @in4(j-1,i,k-1);
out_type l3 = -1e30;
if (k>1 && target_2 != target)
l3 = @in4(j-1,i,k-2);
out_type m = std::max(l1, std::max(l2, l3));
out_type l1_exp = std::exp(l1-m);
out_type l2_exp = std::exp(l2-m);
out_type l3_exp = std::exp(l3-m);
out_type sumexp = l1_exp + l2_exp + l3_exp;
out_type dalpha = @out1(j,i,k);
@out0(j,i,target) += dalpha;
@out1(j-1,i,k) += dalpha * l1_exp / sumexp;
if (k>0)
@out1(j-1,i,k-1) += dalpha * l2_exp / sumexp;
if (k>1 && target_2 != target)
@out1(j-1,i,k-2) += dalpha * l3_exp / sumexp;
}}
// read in0 -> write out0
// write out0 -> read out1
// out0(0,i,0) = in0(0,i,blank);
@out0(0,i,blank) += @out1(0,i,0);
if (target_len)
@out0(0,i,@in1(i,0)) += @out1(0,i,1);
}}
""", cuda_src=f"""
__global__ void kernel(@ARGS_DEF) {{
@PRECALC;
constexpr int blank = {blank};
for (int i=blockIdx.x; i<in0_shape1; i+=gridDim.x) {{
int input_len = @in2(i);
int target_len = @in3(i);
if (input_len==0)
// write out1 --> read in6
// out1(i) = out0(0,i,0);
@out1(0,i,0) = @in6(i);
else {{
out_type l1 = @in4(input_len-1, i, target_len*2);
out_type l2 = -1e30;
if (target_len)
l2 = @in4(input_len-1, i, target_len*2-1);
out_type m = ::max(l1, l2);
// out_type log_likelihood = ::log(::exp(l1-m)+::exp(l2-m))+m;
// out1(i) = -log_likelihood;
out_type l1_exp = ::exp(l1-m);
out_type l2_exp = ::exp(l2-m);
out_type sumexp = l1_exp + l2_exp;
out_type dlog_likelihood = -@in6(i);
out_type dl1 = dlog_likelihood * l1_exp / sumexp;
out_type dl2 = dlog_likelihood * l2_exp / sumexp;
@out1(input_len-1, i, target_len*2) = dl1;
if (target_len)
@out1(input_len-1, i, target_len*2-1) = dl2;
}}
for (int j=input_len-1; j>0; j--)
for (int k=threadIdx.x; k-threadIdx.x<target_len*2+1; k+=blockDim.x) {{
__syncthreads();
if (k>=target_len*2+1)
continue;
int target = k%2 ? @in1(i,k/2) : blank;
int target_2 = target;
if (k>1 && k%2) target_2 = @in1(i,k/2-1);
out_type l1 = @in4(j-1,i,k);
out_type l2 = -1e30;
if (k>0) l2 = @in4(j-1,i,k-1);
out_type l3 = -1e30;
if (k>1 && target_2 != target)
l3 = @in4(j-1,i,k-2);
out_type m = ::max(l1, ::max(l2, l3));
out_type l1_exp = ::exp(l1-m);
out_type l2_exp = ::exp(l2-m);
out_type l3_exp = ::exp(l3-m);
out_type sumexp = l1_exp + l2_exp + l3_exp;
out_type dalpha = @out1(j,i,k);
atomicAdd(&@out0(j,i,target), dalpha);
atomicAdd(&@out1(j-1,i,k), dalpha * l1_exp / sumexp);
if (k>0)
atomicAdd(&@out1(j-1,i,k-1), dalpha * l2_exp / sumexp);
if (k>1 && target_2 != target)
atomicAdd(&@out1(j-1,i,k-2), dalpha * l3_exp / sumexp);
}}
// read in0 -> write out0
// write out0 -> read out1
// out0(0,i,0) = in0(0,i,blank);
__syncthreads();
if (threadIdx.x==0) {{
@out0(0,i,blank) += @out1(0,i,0);
if (target_len)
@out0(0,i,@in1(i,0)) += @out1(0,i,1);
}}
}}
}}
kernel<<<std::min(in0_shape1, 1024), std::min(in1_shape1*2+1, 1024)>>>(@ARGS);
""")
return (dlog_probs,)
def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False):
'''The Connectionist Temporal Classification loss.
Reference:
A. Graves et al.: Connectionist Temporal Classification:
Labelling Unsegmented Sequence Data with Recurrent Neural Networks:
https://www.cs.toronto.edu/~graves/icml_2006.pdf
Input:
log_probs: shape is [T, N, C], T is the sequence length, N is the batch size, C is the class number.
targets: shape is [N, S], N is the batch size, S is the target sequence length, elements should be in [0,C).
input_lengths: shape is [N], which represents the lengths of the inputs, elements should be in [0,T].
target_lengths: shape is [N], which represents the lengths of the targets, elements should be in [0,S].
blank (int, default 0): blank label index
reduction (string): reduce batch loss,
if reduction is none, it will return (N,) array,
if reduction is mean or sum, it will return one scalar
zero_infinity (bool, default False):
zero_infinity for grad
Example:
import jittor as jt
T = 50 # Input sequence length
C = 20 # Number of classes (including blank)
N = 16 # Batch size
S = 30 # Target sequence length of longest target in batch (padding length)
S_min = 10 # Minimum target length, for demonstration purposes
input = jt.randn(T, N, C).log_softmax(2)
# Initialize random batch of targets (0 = blank, 1:C = classes)
target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int)
input_lengths = jt.full((N,), T, dtype=jt.int)
target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int)
loss = jt.ctc_loss(input, target, input_lengths, target_lengths)
dinput = jt.grad(loss, input)
'''
result = _CTCLossFunction.apply(log_probs, targets, input_lengths, target_lengths, blank, zero_infinity)
if reduction=="mean":
return result.mean()
elif reduction=="sum":
return result.sum()
assert reduction=="none"
return result
class CTCLoss(jt.Module):
'''The Connectionist Temporal Classification loss.
Reference:
A. Graves et al.: Connectionist Temporal Classification:
Labelling Unsegmented Sequence Data with Recurrent Neural Networks:
https://www.cs.toronto.edu/~graves/icml_2006.pdf
Args:
blank (int, default 0): blank label index
reduction (string): reduce batch loss,
if reduction is none, it will return (N,) array,
if reduction is mean or sum, it will return one scalar
zero_infinity (bool, default False):
zero_infinity for grad
Input:
log_probs: shape is [T, N, C], T is the sequence length, N is the batch size, C is the class number.
targets: shape is [N, S], N is the batch size, S is the target sequence length, elements should be in [0,C).
input_lengths: shape is [N], which represents the lengths of the inputs, elements should be in [0,T].
target_lengths: shape is [N], which represents the lengths of the targets, elements should be in [0,S].
Example:
import jittor as jt
T = 50 # Input sequence length
C = 20 # Number of classes (including blank)
N = 16 # Batch size
S = 30 # Target sequence length of longest target in batch (padding length)
S_min = 10 # Minimum target length, for demonstration purposes
input = jt.randn(T, N, C).log_softmax(2)
# Initialize random batch of targets (0 = blank, 1:C = classes)
target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int)
input_lengths = jt.full((N,), T, dtype=jt.int)
target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int)
ctc_loss = jt.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
dinput = jt.grad(loss, input)
'''
def __init__(self, blank=0, reduction='mean', zero_infinity=False):
self.blank = blank
self.reduction = reduction
self.zero_infinity = zero_infinity
def execute(self, log_probs, targets, input_lengths, target_lengths):
return ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction, self.zero_infinity)

View File

@ -7,12 +7,12 @@ import math
model_urls = {
'res2net50_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_4s-06e79181.pth',
'res2net50_48w_2s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_48w_2s-afed724a.pth',
'res2net50_14w_8s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_14w_8s-6527dddc.pth',
'res2net50_26w_6s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_6s-19041792.pth',
'res2net50_26w_8s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_8s-2c7c9f12.pth',
'res2net101_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_26w_4s-02a759a1.pth',
'res2net50_14w_8s': 'https://cloud.tsinghua.edu.cn/f/2543e4b5646d40a1afa9/?dl=1&fname=/res2net50_14w_8s.pkl',
'res2net50_26w_4s': 'https://cloud.tsinghua.edu.cn/f/927fead9c9884f769d88/?dl=1&fname=/res2net50_26w_4s.pkl',
'res2net50_26w_6s': 'https://cloud.tsinghua.edu.cn/f/067875edf6ca488ba83e/?dl=1&fname=/res2net50_26w_6s.pkl',
'res2net50_26w_8s': 'https://cloud.tsinghua.edu.cn/f/ce1230155a2c4352bf17/?dl=1&fname=/res2net50_26w_8s.pkl',
'res2net50_48w_2s': 'https://cloud.tsinghua.edu.cn/f/b8a4df2b2cb64500b869/?dl=1&fname=/res2net50_48w_2s.pkl',
'res2net101_26w_4s': 'https://cloud.tsinghua.edu.cn/f/b85283bf572649d288bb/?dl=1&fname=/res2net101_26w_4s.pkl',
}

View File

@ -123,7 +123,7 @@ Example::
shape = []
len_c = max(len_a, len_b)
(n, m), (m_, k) = a.shape[-2:], b.shape[-2:]
assert m == m_, f"dimension not match, a.shape:{a.shape}, b.shape:{a.shape}"
assert m == m_, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}"
# a: [..., n, m]
# b: [..., m, k]
# cc:[..., n, m, k]
@ -141,7 +141,7 @@ Example::
an = a.shape[ai] if ai>=0 else 1
bn = b.shape[bi] if bi>=0 else 1
if an!=1 and bn!=1:
assert an == bn, f"dimension not match, a.shape:{a.shape}, b.shape:{a.shape}"
assert an == bn, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}"
cn = max(an, bn)
shape.append(cn)
shape.extend([n, m, k])
@ -341,16 +341,20 @@ def softmax(x, dim = None):
x = (x-x.max(dim, keepdims=True)).exp()
ret = x / x.sum(dim, keepdims=True)
return ret
jt.Var.softmax = softmax
def log_softmax(x,dim=None):
x = softmax(x,dim=dim)
return jt.log(x)
jt.Var.log_softmax = log_softmax
def log_sigmoid(x):
return jt.log(jt.sigmoid(x))
jt.Var.log_sigmoid = log_sigmoid
def logsumexp(x, dim, keepdim=False):
return x.exp().sum(dim, keepdim).log()
jt.Var.logsumexp = logsumexp
class Identity(Module):
def __init__(self, *args, **kwargs):
@ -1612,7 +1616,8 @@ class Sequential(Module):
return
parents.append(self)
for k,v in self.layers.items():
v.dfs(parents, k, callback, callback_leave)
if isinstance(v, Module):
v.dfs(parents, k, callback, callback_leave)
parents.pop()
if callback_leave:
callback_leave(parents, k, self, n_children)

View File

@ -30,8 +30,10 @@ int current_seed;
EXTERN_LIB list<VarPtr> fetcher;
EXTERN_LIB list<VarPtr> fetcher_to_free;
EXTERN_LIB vector<void(*)()> cleanup_callback;
EXTERN_LIB bool exited;
void cleanup() {
exited = true;
fetcher_to_free.clear();
fetcher.clear();
for (auto cb : cleanup_callback)

View File

@ -17,7 +17,9 @@ void* AlignedAllocator::alloc(size_t size, size_t& allocation) {
#ifndef _WIN32
#ifdef __APPLE__
size += 32-size%32;
#endif
// low version of mac don't have aligned_alloc
return new char[size];
#else
return aligned_alloc(alignment, size);
#else
return _aligned_malloc(size, alignment);
@ -28,8 +30,12 @@ void AlignedAllocator::free(void* mem_ptr, size_t size, const size_t& allocation
#ifdef _WIN32
_aligned_free(mem_ptr);
#else
#ifdef __APPLE__
delete[] (char*)mem_ptr;
#else
::free(mem_ptr);
#endif
#endif
}
} // jittor

View File

@ -241,7 +241,7 @@ void segfault_sigaction(int signal, siginfo_t *si, void *arg) {
return;
}
if (signal == SIGCHLD) {
if (si->si_code != CLD_EXITED && si->si_status != SIGTERM) {
if (si->si_code != CLD_EXITED && si->si_status != SIGTERM && _pid == getpid()) {
LOGe << "Caught SIGCHLD"
<< "si_errno:" << si->si_errno
<< "si_code:" << si->si_code
@ -259,6 +259,7 @@ void segfault_sigaction(int signal, siginfo_t *si, void *arg) {
exited = true;
do_exit();
}
if (exited) do_exit();
std::cerr << "Caught segfault at address " << si->si_addr << ", "
<< "thread_name: '" << thread_name << "', flush log..." << std::endl;
std::cerr.flush();
@ -556,9 +557,11 @@ void system_with_check(const char* cmd, const char* cwd) {
std::thread log_thread(log_main);
#endif
int log_exit = 0;
void log_exiting() {
if (exited) return;
exited = true;
if (log_exit) return;
log_exit = true;
for (auto cb : cleanup_callback)
cb();
cleanup_callback.clear();

View File

@ -7,6 +7,7 @@
import unittest
import jittor as jt
import numpy as np
import os
def expect_error(func):
try:
@ -86,6 +87,17 @@ class TestCore(unittest.TestCase):
c = np.matmul(a, b)
jtc = jt.matmul(jt.array(a), jt.array(b)).data
assert np.all(jtc == c)
def test_save_load_sub_module(self):
class Net(jt.Module):
def __init__(self):
self.conv1 = jt.nn.Conv(3,3,3)
net = Net()
assert list(net.named_parameters().keys()) == ['conv1.weight', 'conv1.bias']
assert list(net.conv1.named_parameters().keys()) == ['weight', 'bias']
pkl_name = os.path.join(jt.flags.cache_path, "sub.pkl")
net.conv1.save(pkl_name)
net.conv1.load(pkl_name)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,103 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Guoye Yang <498731903@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
from jittor import compile_extern
if jt.has_cuda:
from jittor.compile_extern import cublas_ops, cudnn_ops, cub_ops
else:
cublas_ops = cudnn_ops = cub_ops = None
def test_forward(shape, dim=None):
x = jt.random(shape)
y = jt.numpy_cumsum(x)
y_ = jt.cub_cumsum(x)
assert(np.allclose(y.data, y_.data))
def test_backward(shape, dim=None):
x = jt.random(shape)
z = jt.random(shape)
y = jt.numpy_cumsum(x)
loss = (y * z).sum()
grad = jt.grad(loss, x)
y_ = jt.cub_cumsum(x)
loss_ = (y_ * z).sum()
grad_ = jt.grad(loss_, x)
assert(np.allclose(grad.data, grad_.data))
class TestCubCumsumOp(unittest.TestCase):
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
@jt.flag_scope(use_cuda=1)
def test_1d(self):
test_forward([20])
test_forward([3007])
test_forward([3007], 0)
test_forward([3007], -1)
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
@jt.flag_scope(use_cuda=1)
def test_1d_backward(self):
test_backward([20])
test_backward([3007])
test_backward([3007], 0)
test_backward([3007], -1)
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
@jt.flag_scope(use_cuda=1)
def test_2d(self):
test_forward([5,5])
test_forward([2000, 3007])
test_forward([2000, 3007], 1)
test_forward([2000, 3007], -1)
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
@jt.flag_scope(use_cuda=1)
def test_2d_backward(self):
test_backward([5,5])
test_backward([2000, 3007])
test_backward([2000, 3007], 1)
test_backward([2000, 3007], -1)
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
@jt.flag_scope(use_cuda=1)
def test_nd(self):
test_forward([5,6,7,8], 0)
test_forward([5,6,7,8], 1)
test_forward([5,6,7,8], 2)
test_forward([5,6,7,8], 3)
test_forward([5,6,7,8], -1)
test_forward([16,14,14,2048], 0)
test_forward([16,14,14,2048], 1)
test_forward([16,14,14,2048], 2)
test_forward([16,14,14,2048], 3)
test_forward([16,14,14,2048], -1)
test_forward([16,14,14,2047], 3)
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
@jt.flag_scope(use_cuda=1)
def test_nd_backward(self):
test_backward([5,6,7,8], 0)
test_backward([5,6,7,8], 1)
test_backward([5,6,7,8], 2)
test_backward([5,6,7,8], 3)
test_backward([5,6,7,8], -1)
test_backward([16,14,14,2048], 0)
test_backward([16,14,14,2048], 1)
test_backward([16,14,14,2048], 2)
test_backward([16,14,14,2048], 3)
test_backward([16,14,14,2048], -1)
test_backward([16,14,14,2047], 3)
if __name__ == "__main__":
unittest.main()

View File

@ -267,6 +267,45 @@ if __name__ == "__main__":
assert "quick exit" in s
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
def test_dataset_shuffle_mpi(self):
src = """
import jittor as jt
from jittor.dataset import Dataset
import numpy as np
class YourDataset(Dataset):
def __init__(self):
super().__init__()
self.set_attrs(total_len=160, shuffle=True)
def __getitem__(self, k):
return k
dataset = YourDataset()
dataset.set_attrs(num_workers=2)
for d in dataset:
for a in d:
print("CHECK: ", a.item())
"""
fname = os.path.join(jt.flags.cache_path, "test_dataset_shuffle_mpi.py")
with open(fname, 'w') as f:
f.write(src)
import subprocess as sp
import sys
cmd = sys.executable + " " + fname
mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
cmd = mpirun_path + " -np 2 " + cmd
print(cmd)
r = sp.run(cmd, shell=True, stdout=sp.PIPE, stderr=sp.PIPE)
s = r.stdout.decode()
# print(s)
st = set([ l for l in s.splitlines() if l.startswith("CHECK:") ])
assert r.returncode == 0
# print(st)
assert len(st) == 160, len(st)
def test_children_died2(self):
src = """
import jittor as jt

View File

@ -0,0 +1,31 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
from jittor.dataset.mnist import EMNIST, MNIST
import jittor.transform as transform
@unittest.skipIf(True, f"skip emnist test")
class TestEMNIST(unittest.TestCase):
def test_emnist(self):
import pylab as pl
# emnist_dataset = EMNIST()
emnist_dataset = EMNIST()
for imgs, labels in emnist_dataset:
print(imgs.shape, labels.shape)
print(labels.max(), labels.min())
# imgs = imgs.transpose(0,1,3,2).transpose(1,2,0,3)[0].reshape(28, -1)
imgs = imgs.transpose(1,2,0,3)[0].reshape(28, -1)
print(labels)
pl.imshow(imgs), pl.show()
break
if __name__ == "__main__":
unittest.main()

View File

@ -47,7 +47,17 @@ class TestLoadPth(unittest.TestCase):
jt_out = jt_model(jt_img)
torch_out = torch_model(torch_img)
print(np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())))
assert np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())) < 1e-4
assert np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())) < 1e-3
pth_name = os.path.join(jt.flags.cache_path, "x.pth")
torch.save(torch_model.state_dict, pth_name)
jt_model.load(pth_name)
# output
jt_out = jt_model(jt_img)
# torch_out = torch_model(torch_img)
print(np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())))
assert np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())) < 1e-3
if __name__ == "__main__":
unittest.main()

View File

@ -1,3 +1,13 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Zheng-Ning Liu <lzhengning@gmail.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import numpy as np

View File

@ -165,6 +165,8 @@ jt.mkl_ops.mkl_conv(x, w, 1, 1, 2, 2).sync()
n += 1
assert n == 2
assert list(x.keys()) == [0,1]
p = x.parameters()
assert len(p)==0
# def test_res2net(self):
# import jittor.models

View File

@ -159,6 +159,86 @@ class TestPad(unittest.TestCase):
out.detach().numpy(), output.data,
atol=1e-4)
def test_ctc_loss(self):
def check(T,C,N,S,S_min):
jt.set_global_seed(0)
# Initialize random batch of input vectors, for *size = (T,N,C)
input = jt.randn(T, N, C).log_softmax(2)
# input = -jt.ones((T, N, C))
# input[0,0,1] += 0.01
# Initialize random batch of targets (0 = blank, 1:C = classes)
target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int)
_input_jt = input
input_lengths = jt.full((N,), T, dtype=jt.int)
target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int)
# ctc_loss = nn.CTCLoss()
loss = jt.ctc_loss(input, target, input_lengths, target_lengths, reduction='none')
_loss_jt = loss
loss_jt = loss.numpy()
input = torch.Tensor(input.numpy()).detach().requires_grad_()
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
target_lengths = torch.LongTensor(target_lengths.numpy())
input_lengths = torch.LongTensor(input_lengths.numpy())
target = torch.LongTensor(target.numpy())
loss = tnn.CTCLoss(reduction='none')(input, target, input_lengths, target_lengths)
np.testing.assert_allclose(loss.detach().numpy(), loss_jt, rtol=1e-5, atol=1e-5)
dinput_jt = jt.grad(_loss_jt, _input_jt)
dinput_jt.sync()
loss.sum().backward()
# print(input.grad)
# print(dinput_jt)
# print(loss)
def check_gpu_with_cpu(T,C,N,S,S_min):
jt.set_global_seed(1)
# Initialize random batch of input vectors, for *size = (T,N,C)
input = jt.randn(T, N, C).log_softmax(2)
# input = -jt.ones((T, N, C))
# input[0,0,1] += 0.01
# Initialize random batch of targets (0 = blank, 1:C = classes)
target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int)
_input_jt = input
input_lengths = jt.full((N,), T, dtype=jt.int)
target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int)
# ctc_loss = nn.CTCLoss()
loss = jt.ctc_loss(input, target, input_lengths, target_lengths, reduction='none')
_loss_jt = loss
loss_jt = loss.numpy()
dinput_jt = jt.grad(_loss_jt, _input_jt)
dinput_jt.sync()
with jt.flag_scope(use_cuda=1):
input = input.copy()
target = target.copy()
input_lengths = input_lengths.copy()
target_lengths = target_lengths.copy()
loss = jt.ctc_loss(input, target, input_lengths, target_lengths, reduction='none')
grad = jt.grad(loss, input)
np.testing.assert_allclose(_loss_jt.numpy(), loss.numpy(), atol=1e-5, rtol=1e-5)
np.testing.assert_allclose(dinput_jt.numpy(), grad.numpy(), atol=1e-5, rtol=1e-5)
check(2,2,1,1,1)
check(50,20,16,30,10)
if jt.has_cuda:
with jt.flag_scope(use_cuda=1):
check(2,2,1,1,1)
check(50,20,16,30,10)
check_gpu_with_cpu(50,20,16,30,10)
class TestOther(unittest.TestCase):
def test_save(self):
pp = [1,2,jt.array([1,2,3]), {"a":[1,2,3], "b":jt.array([1,2,3])}]

View File

@ -0,0 +1,61 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Zheng-Ning Liu <lzhengning@gmail.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
class TestRepeatOp(unittest.TestCase):
def test_repeat(self):
np_a = np.arange(5)
jt_a = jt.array(np_a)
np_b = np.tile(np_a, (2, 3))
jt_b = jt.repeat(jt_a, (2, 3))
assert np.allclose(np_b, jt_b.data)
np_b = np.tile(np_a, (2, 3, 1))
jt_b = jt.repeat(jt_a, (2, 3, 1))
assert np.allclose(np_b, jt_b.data)
np_a = np.arange(24).reshape(2, 3, 4)
jt_a = jt.array(np_a)
np_b = np.tile(np_a, (2, 3))
jt_b = jt.repeat(jt_a, (2, 3))
assert np.allclose(np_b, jt_b.data)
def test_highdim(self):
np_a = np.arange(64).reshape(2, 2, 2, 2, 2, 2)
jt_a = jt.array(np_a)
np_b = np.tile(np_a, (2, 3))
jt_b = jt.repeat(jt_a, (2, 3))
assert np.allclose(np_b, jt_b.data)
np_b = np.tile(np_a, (2, 1, 1, 3))
jt_b = jt.repeat(jt_a, (2, 1, 1, 3))
assert np.allclose(np_b, jt_b.data)
np_b = np.tile(np_a, (2, 1, 1, 1, 3, 1))
jt_b = jt.repeat(jt_a, (2, 1, 1, 1, 3, 1))
assert np.allclose(np_b, jt_b.data)
if __name__ == "__main__":
unittest.main()

View File

@ -954,6 +954,22 @@ class Tester(unittest.TestCase):
transform.ToTensor(),
])(img)
def test_not_pil_image(self):
img = jt.random((30,40,3))
result = transform.Compose([
transform.RandomAffine(20),
transform.ToTensor(),
])(img)
img = jt.random((30,40,3))
result = transform.Compose([
transform.ToPILImage(),
transform.Gray(),
transform.Resize(20),
transform.ToTensor(),
])(img)
if __name__ == '__main__':

View File

@ -152,7 +152,9 @@ class Crop:
self.left = left
self.height = height
self.width = width
def __call__(self, img):
def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
return crop(img, self.top, self.left, self.height, self.width)
@ -181,6 +183,8 @@ class RandomCropAndResize:
self.interpolation = interpolation
def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
width, height = img.size
scale = self.scale
ratio = self.ratio
@ -363,6 +367,8 @@ class RandomHorizontalFlip:
self.p = p
def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
if random.random() < self.p:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img
@ -384,6 +390,8 @@ class CenterCrop:
self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values")
def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
width, height = img.size
return crop(img, (height - self.size[0]) / 2, (width - self.size[1]) / 2, self.size[0], self.size[1])
@ -682,6 +690,8 @@ class Resize:
self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values")
self.mode = mode
def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
return resize(img, self.size, self.mode)
class Gray:
@ -697,6 +707,8 @@ class Gray:
self.num_output_channels = num_output_channels
def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
img = np.float32(img.convert('L')) / np.float32(255.0)
if self.num_output_channels == 1:
return img[np.newaxis, :]
@ -720,7 +732,9 @@ class RandomGray:
def __init__(self, p=0.1):
self.p = p
def __call__(self, img: Image.Image):
def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
num_output_channels = _get_image_num_channels(img)
if random.random() < self.p:
return gray(img, num_output_channels=num_output_channels)
@ -742,6 +756,8 @@ class RandomCrop:
def __init__(self, size):
self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values")
def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
width, height = img.size
assert self.size[0] <= height and self.size[1] <= width, f"crop size exceeds the input image in RandomCrop, {(self.size, height, width)}"
top = np.random.randint(0,height-self.size[0]+1)
@ -835,7 +851,9 @@ class RandomVerticalFlip:
def __init__(self, p=0.5):
self.p = p
def __call__(self, img: Image.Image):
def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
if random.random() < self.p:
return vflip(img)
return img
@ -918,13 +936,15 @@ class ColorJitter:
return transform
def __call__(self, img):
def __call__(self, img:Image.Image):
"""
Args::
[in] img (PIL Image): Input image.
Returns::
[out] PIL Image: Color jittered image.
"""
if not isinstance(img, Image.Image):
img = to_pil_image(img)
transform = self._get_transform(self.brightness, self.contrast, self.saturation, self.hue)
return transform(img)
@ -1002,7 +1022,7 @@ class RandomPerspective(object):
self.interpolation = interpolation
self.distortion_scale = distortion_scale
def __call__(self, img):
def __call__(self, img:Image.Image):
"""
Args:
img (PIL Image): Image to be Perspectively transformed.
@ -1011,7 +1031,7 @@ class RandomPerspective(object):
PIL Image: Randomly perspective-transformed image.
"""
if not isinstance(img, Image.Image):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
img = to_pil_image(img)
if random.random() < self.p:
width, height = img.size
@ -1119,7 +1139,7 @@ class RandomResizedCrop(object):
j = (width - w) // 2
return i, j, h, w
def __call__(self, img):
def __call__(self, img:Image.Image):
"""
Args:
img (PIL Image): Image to be cropped and resized.
@ -1127,6 +1147,8 @@ class RandomResizedCrop(object):
Returns:
PIL Image: Randomly cropped and resized image.
"""
if not isinstance(img, Image.Image):
img = to_pil_image(img)
i, j, h, w = self.get_params(img, self.scale, self.ratio)
return F_pil.resized_crop(img, i, j, h, w, self.size, self.interpolation)
@ -1174,7 +1196,9 @@ class FiveCrop(object):
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
self.size = size
def __call__(self, img):
def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
return F_pil.five_crop(img, self.size)
def __repr__(self):
@ -1217,7 +1241,9 @@ class TenCrop(object):
self.size = size
self.vertical_flip = vertical_flip
def __call__(self, img):
def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
return F_pil.ten_crop(img, self.size, self.vertical_flip)
def __repr__(self):
@ -1275,7 +1301,7 @@ class RandomRotation(object):
return angle
def __call__(self, img):
def __call__(self, img:Image.Image):
"""
Args:
img (PIL Image): Image to be rotated.
@ -1283,7 +1309,8 @@ class RandomRotation(object):
Returns:
PIL Image: Rotated image.
"""
if not isinstance(img, Image.Image):
img = to_pil_image(img)
angle = self.get_params(self.degrees)
return F_pil.rotate(img, angle, self.resample, self.expand, self.center, self.fill)
@ -1405,13 +1432,15 @@ class RandomAffine(object):
return angle, translations, scale, shear
def __call__(self, img):
def __call__(self, img:Image.Image):
"""
img (PIL Image): Image to be transformed.
Returns:
PIL Image: Affine transformed image.
"""
if not isinstance(img, Image.Image):
img = to_pil_image(img)
ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
return F_pil.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor)