mirror of https://github.com/Jittor/Jittor
Merge branch 'master' of github.com:Jittor/jittor
commit 500dfc6ee3
@ -0,0 +1,76 @@
Jittor Debugging Tips
=====================

This document covers debugging methods and tips for several kinds of abnormal situations.

## NaN and Inf

During model training, NaN or Inf values may appear because of numerical instability. To help you locate the code that produces the NaN, you can set the following environment variables:

```bash
export JT_CHECK_NAN=1
export trace_py_var=3
```

The environment variable `JT_CHECK_NAN=1` makes Jittor raise an error and stop the program as soon as an operator outputs an abnormal floating-point value; `trace_py_var=3` makes Jittor report the Python source lines that correspond to each operator, where 3 is the highest verbosity level.

Note that enabling these two features slows Jittor down considerably and triggers recompilation, so do not enable this mode in a training or production environment, and do not leave it on for long periods.
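
For a one-off debugging run, the two variables are usually set inline (a sketch; `train.py` is a placeholder for your own script):

```bash
JT_CHECK_NAN=1 trace_py_var=3 python3.7 train.py
```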

## Inaccurate error locations

By default, Jittor uses lazy execution for acceleration: operator creation and operator execution are decoupled, which can make error messages point to the wrong location. You can switch to eager execution manually with the following environment variable:

```bash
export lazy_execution=0
```

or turn it off in Python through a flag:

```python
jt.flags.lazy_execution=0
```
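
For instance (a sketch; `model` and `x` stand for your own code), eager execution can be enabled only while hunting a bug and switched back afterwards:

```python
import jittor as jt

jt.flags.lazy_execution = 0   # eager mode: errors point at the real line
y = model(x)                  # the call that previously failed
jt.flags.lazy_execution = 1   # back to the default lazy mode
```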

## Out of memory

When Jittor cannot run because of memory-related problems, it reports the current memory usage. Running out of memory usually falls into one of two cases:

1. The model is too large, and the program crashes within a single iteration.
2. Memory usage keeps growing across iterations until it is exhausted.

**For the first case**, you may need to reduce the model or the data size, or use [multi-GPU training](jittor.mpi). In addition, you can force Jittor to reclaim memory inside every iteration:

```python
for ...:
    ...
    jt.sync_all()
    jt.gc()
```

If you are using CUDA and convolutions, the temporary workspace consumed by convolution may also be too large. In that case you can disable cuDNN's temporary memory allocation by inserting the following code at the very beginning of your program:

```python
jt.cudnn.set_max_workspace_ratio(0.0)
```

**For the second case**, there may be a memory leak. Check whether a global variable is never released, or whether a global variable does not stop gradients, causing the computation graph to grow without bound. To check, insert the following debugging code inside every iteration:

```python
for ...:
    ...
    jt.sync_all()
    jt.display_memory_info()
```

Jittor will print the memory consumption, the size of the computation graph (`lived_var`, `lived_op`), and the number of variables held by the user (`hold_var`). If the graph keeps growing, please check your code, or contact us by filing a GitHub issue with the error log and a script that reproduces the problem.
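
A common source of this kind of leak is accumulating a loss Var into a Python variable, which keeps every iteration's graph alive. A minimal sketch of the fix (the helper name is illustrative):

```python
total_loss = 0.0
for ...:
    loss = compute_loss(...)   # illustrative helper
    # BAD:  total_loss += loss      keeps the graph of every iteration alive
    # GOOD: take a Python float so this iteration's graph can be freed
    total_loss += loss.item()
```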

## Segmentation faults

If Jittor crashes with a segmentation fault, we recommend filing a GitHub issue with the error log and a script that reproduces the problem. You can also diagnose the program and the framework with the following environment variables:

```bash
export debug=1
export gdb_attach=1
```

The environment variable `debug=1` enables Jittor's debug mode: performance drops significantly, but debugging information is preserved. `gdb_attach=1` automatically attaches gdb to Jittor's main process so that you can step through it. For gdb usage, see the [GDB Cheat Sheet](https://darkdust.net/files/GDB%20Cheat%20Sheet.pdf).
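
A one-off diagnostic run then looks like this (a sketch; `train.py` is a placeholder for your own script):

```bash
debug=1 gdb_attach=1 python3.7 train.py
```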
@ -45,7 +45,8 @@ language = 'zh_CN'
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'recommonmark',
    # 'recommonmark',
    'myst_parser',
    'sphinx.ext.autodoc',
    # Auto-generate section labels.
    'sphinx.ext.autosectionlabel',
@ -32,10 +32,22 @@
jittor.loss3d


.. toctree::
   :maxdepth: 2
   :caption: Jittor Model Zoo:

   JDet
   segmentation-jittor
   InstanceSegmentation-jittor
   gan-jittor
   PointCloudLib
   jrender

.. toctree::
   :maxdepth: 1
   :caption: Others:

   Jittor调试技巧
   Tutorials <https://cg.cs.tsinghua.edu.cn/jittor/tutorial/>

Indices and tables
@ -1,7 +1,37 @@
jittor.mpi
=====================

This is the API documentation for Jittor's MPI module, which you can access with `from jittor import mpi`.
Jittor's distributed support is built on MPI (Message Passing Interface). This document mainly explains how to use Jittor MPI for multi-GPU and distributed training.


## Installing Jittor MPI

Jittor depends on `OpenMPI`, which can be installed with the following command:

```bash
sudo apt install openmpi-bin openmpi-common libopenmpi-dev
```

Jittor automatically checks whether `mpicc` is present in the environment. If `mpicc` is detected successfully, Jittor prints a message like:

```
[i 0502 14:09:55.758481 24 __init__.py:203] Found mpicc(1.10.2) at /usr/bin/mpicc
```

If Jittor does not find MPI in the environment, you can also point it to `mpicc` manually by setting an environment variable: `export mpicc_path=/your/mpicc/path`

Once `OpenMPI` is installed, no code changes are needed: you only change the launch command, and Jittor automatically runs the job in a data-parallel fashion.

```bash
# single-GPU training
python3.7 -m jittor.test.test_resnet
# distributed multi-GPU training
mpirun -np 4 python3.7 -m jittor.test.test_resnet
# multi-GPU training on specific GPUs
CUDA_VISIBLE_DEVICES="2,3" mpirun -np 2 python3.7 -m jittor.test.test_resnet
```

This convenience is backed by Jittor's distributed operators, and the supported MPI operator backends use NCCL for further acceleration. All of Jittor's distributed algorithms are developed in the Python frontend, which makes them much more flexible and far easier to develop.

## Adapting single-GPU code to multi-GPU code
@ -11,6 +41,8 @@ jittor.mpi
* jittor.nn.BatchNorm*: synchronized batch norm
* jittor.dataset: automatic data parallelism

When training with MPI, Jittor's internal Dataset class automatically distributes the data across processes in parallel. Note that the batch size set on the Dataset class is **the sum of the batch sizes of all nodes**, i.e. the total batch size, not the batch size received by a single node.
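
For example (a sketch; the dataset class and the numbers are illustrative), with a total `batch_size=64` and `mpirun -np 4`, each of the 4 processes receives 64 / 4 = 16 samples per batch:

```python
# batch_size is the TOTAL batch size summed over all MPI processes
train_loader = YourDataset().set_attrs(batch_size=64, shuffle=True)
for images, labels in train_loader:
    ...   # each process sees 16 samples per batch under `mpirun -np 4`
```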

In most cases, single-GPU training code can be run distributed across multiple GPUs simply with `mpirun`. However, the code still needs to be adjusted in the following situations:

1. Writing to disk (saving models, saving curves)
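
For the disk-writing case, a minimal sketch of the usual adjustment is to let only one process touch the disk (the checkpoint name is illustrative):

```python
# only the rank-0 process writes the checkpoint, so processes do not clobber each other
if jt.rank == 0:
    model.save("checkpoint.pkl")
```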
@ -93,10 +125,30 @@ def val(epoch):
......
```

Below is the MPI API reference of jittor.

## MPI Interfaces

Below is the MPI API reference of jittor.
The MPI interfaces currently exposed are:

* `jt.in_mpi`: when Jittor is not running inside an MPI environment, `jt.mpi == False`; you can use this to check whether you are in an MPI environment.
* `jt.world_rank`: the total number of processes; 1 if MPI is not used.
* `jt.rank`: the index of the current process, in the range `0 ~ jt.world_rank-1`; 0 if MPI is not used.
* `jt.mpi`: Jittor's MPI module.
* `jt.Module.mpi_param_broadcast(root=0)`: broadcast the module's parameters from the root node to the other nodes.
* `jt.mpi.mpi_reduce(x, op='add', root=0)`: reduce the variable x of all nodes to the root node with operator op. If op is 'add' or 'sum', this interface sums the variables; if op is 'mean', it takes the mean.

<img src="https://cg.cs.tsinghua.edu.cn/jittor/images/tutorial/2020-5-2-16-44-distributed/mpi_reduce.png">
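
As a small usage sketch (the loss variable is illustrative), `mpi_reduce` can gather a per-process loss onto the root process for logging:

```python
total_loss = jt.mpi.mpi_reduce(loss, op='add', root=0)
if jt.rank == 0:
    print("summed loss across processes:", total_loss.item())
```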

* `jt.mpi.mpi_broadcast(x, root=0)`: broadcast the variable x from the root node to all nodes.

<img src="https://cg.cs.tsinghua.edu.cn/jittor/images/tutorial/2020-5-2-16-44-distributed/mpi_broadcast.png">

* `jt.mpi.mpi_all_reduce(x, op='add')`: reduce the variable x across all nodes and broadcast the reduced result back to every node. If op is 'add' or 'sum', this interface sums the variables; if op is 'mean', it takes the mean.

<img src="https://cg.cs.tsinghua.edu.cn/jittor/images/tutorial/2020-5-2-16-44-distributed/mpi_all_reduce.png">
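
As a usage sketch (the accuracy variables are illustrative), `mpi_all_reduce` can average a per-process metric so that every node ends up with the same value:

```python
acc = jt.array([correct / total])    # per-process accuracy
acc = acc.mpi_all_reduce("mean")     # averaged over all MPI processes
print(jt.rank, acc.item())           # every rank prints the same number
```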


```eval_rst
.. automodule:: jittor_mpi_core
@ -106,3 +158,56 @@ def val(epoch):
    :members:
    :undoc-members:
```

## Example: a distributed synchronized batch normalization layer with MPI

The code below implements a distributed synchronized batch normalization layer with Jittor. Starting from an ordinary batch normalization layer, only three extra lines are needed to make the batch norm distributed. The added lines are:

```python
# synchronize the mean and variance to all nodes via all reduce
if self.sync and jt.mpi:
    xmean = xmean.mpi_all_reduce("mean")
    x2mean = x2mean.mpi_all_reduce("mean")
```

> Note: Jittor already ships a synchronized batch normalization layer, so users do not need to implement it themselves.
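
Since `jittor.nn.BatchNorm*` layers are already synchronized under MPI (see the operator list above), user code normally just uses them directly; a sketch:

```python
from jittor import nn
bn = nn.BatchNorm(64)   # statistics are synchronized across processes under mpirun
```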

The complete code of the distributed synchronized batch normalization layer:

```python
class BatchNorm(Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
        assert affine == None

        self.sync = sync
        self.num_features = num_features
        self.is_train = is_train
        self.eps = eps
        self.momentum = momentum
        self.weight = init.constant((num_features,), "float32", 1.0)
        self.bias = init.constant((num_features,), "float32", 0.0)
        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()

    def execute(self, x):
        if self.is_train:
            xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
            x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
            # synchronize the mean and variance to all nodes via all reduce
            if self.sync and jt.mpi:
                xmean = xmean.mpi_all_reduce("mean")
                x2mean = x2mean.mpi_all_reduce("mean")

            xvar = x2mean-xmean*xmean
            norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
            self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum
            self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum
        else:
            running_mean = self.running_mean.broadcast(x, [0,2,3])
            running_var = self.running_var.broadcast(x, [0,2,3])
            norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
        w = self.weight.broadcast(x, [0,2,3])
        b = self.bias.broadcast(x, [0,2,3])
        return norm_x * w + b
```
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.92'
|
||||
__version__ = '1.2.3.102'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -69,6 +69,15 @@ def safeunpickle(path):
|
|||
except:
|
||||
raise RuntimeError("pytorch need to be installed when load pth format.")
|
||||
model_dict = torch.load(path, map_location=torch.device('cpu'))
|
||||
try:
|
||||
for k, v in model_dict.items():
|
||||
try:
|
||||
if not isinstance(v, np.ndarray) and hasattr(v, "cpu"):
|
||||
model_dict[k] = v.cpu().detach().numpy()
|
||||
except:
|
||||
pass
|
||||
except:
|
||||
pass
|
||||
return model_dict
|
||||
with open(path, "rb") as f:
|
||||
s = f.read()
|
||||
|
@ -332,12 +341,12 @@ def std(x):
|
|||
return out
|
||||
Var.std = std
|
||||
|
||||
def norm(x, k=2, dim=-1, keepdim=False):
|
||||
assert k==1 or k==2
|
||||
if k==1:
|
||||
def norm(x, p=2, dim=-1, keepdim=False, eps=1e-30):
|
||||
assert p==1 or p==2
|
||||
if p==1:
|
||||
return x.abs().sum(dim, keepdim)
|
||||
if k==2:
|
||||
return (x.sqr()).sum(dim, keepdim).maximum(1e-6).sqrt()
|
||||
if p==2:
|
||||
return (x.sqr()).sum(dim, keepdim).maximum(eps).sqrt()
|
||||
Var.norm = norm
|
||||
|
||||
origin_reshape = reshape
|
||||
|
@ -796,12 +805,30 @@ class Module:
|
|||
return _uniq(ps)
|
||||
|
||||
def named_parameters(self):
|
||||
ps = self.parameters()
|
||||
return [ (p.name(), p) for p in ps ]
|
||||
uniq_set = set()
|
||||
ps = {}
|
||||
stack = []
|
||||
def callback(parents, k, v, n):
|
||||
stack.append(str(k))
|
||||
dc = v.__dict__
|
||||
if isinstance(v, nn.ParameterList):
|
||||
dc = v.params
|
||||
for k2, p in dc.items():
|
||||
if isinstance(k2, str) and k2.startswith("_"): continue
|
||||
if isinstance(p, Var):
|
||||
if id(p) in uniq_set: continue
|
||||
uniq_set.add(id(p))
|
||||
pname = ".".join(stack[1:]+[str(k2)])
|
||||
ps[pname] = p
|
||||
if len(pname) > len(p.name()):
|
||||
p.name(pname)
|
||||
def callback_leave(parents, k, v, n):
|
||||
stack.pop()
|
||||
self.dfs([], None, callback, callback_leave)
|
||||
return ps
|
||||
|
||||
def state_dict(self):
|
||||
ps = self.parameters()
|
||||
return { p.name(): p for p in ps }
|
||||
return self.named_parameters()
|
||||
|
||||
def load_state_dict(self, params):
|
||||
self.load_parameters(params)
|
||||
|
@ -1011,10 +1038,13 @@ Arguments of hook are defined as::
|
|||
>>> net.save('net.pkl')
|
||||
>>> net.load('net.pkl')
|
||||
'''
|
||||
params = self.parameters()
|
||||
params = self.named_parameters()
|
||||
params_dict = {}
|
||||
for p in params:
|
||||
params_dict[p.name()] = p.data
|
||||
for k, v in params.items():
|
||||
if isinstance(v, Var):
|
||||
params_dict[k] = v.numpy()
|
||||
else:
|
||||
params_dict[k] = v
|
||||
safepickle(params_dict, path)
|
||||
|
||||
def load(self, path: str):
|
||||
|
|
|
@ -127,6 +127,7 @@ class Dataset(object):
|
|||
self.epoch_id = 0
|
||||
self.sampler = None
|
||||
self._disable_workers = False
|
||||
self._shuffle_rng = np.random.default_rng(1)
|
||||
|
||||
def __getitem__(self, index):
|
||||
raise NotImplementedError
|
||||
|
@ -214,8 +215,8 @@ class Dataset(object):
|
|||
jittor_utils.cc.init_subprocess()
|
||||
jt.jt_init_subprocess()
|
||||
seed = jt.get_seed()
|
||||
wseed = (seed ^ worker_id) ^ 1234
|
||||
jt.set_seed(wseed)
|
||||
wseed = (seed ^ (worker_id*1167)) ^ 1234
|
||||
jt.set_global_seed(wseed)
|
||||
# parallel_op_compiler still problematic,
|
||||
# it does not work on ubuntu 16.04, but works on ubuntu 20.04
|
||||
# it seems like the static value of parallel compiler
|
||||
|
@ -426,7 +427,10 @@ Example::
|
|||
elif self.shuffle == False:
|
||||
index_list = get_order_list(self.total_len)
|
||||
else:
|
||||
index_list = get_random_list(self.total_len)
|
||||
# using _shuffle_rng to generate multiprocess
|
||||
# consistent shuffle list
|
||||
# index_list = get_random_list(self.total_len)
|
||||
index_list = self._shuffle_rng.permutation(range(self.total_len))
|
||||
|
||||
# scatter index_list for all mpi process
|
||||
# scatter rule:
|
||||
|
|
|
@ -7,6 +7,8 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
import os
|
||||
import string
|
||||
import numpy as np
|
||||
import gzip
|
||||
from PIL import Image
|
||||
|
@ -94,3 +96,105 @@ class MNIST(Dataset):
|
|||
for url, md5 in resources:
|
||||
filename = url.rpartition('/')[2]
|
||||
download_url_to_local(url, filename, self.data_root, md5)
|
||||
|
||||
class EMNIST(Dataset):
|
||||
'''
|
||||
Jittor's own class for loading EMNIST dataset.
|
||||
|
||||
Args::
|
||||
|
||||
[in] data_root(str): your data root.
|
||||
[in] split(str): one of 'byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist'.
|
||||
[in] train(bool): use the training split if True, otherwise the test split.
|
||||
[in] download(bool): download the data automatically if download is True.
|
||||
[in] batch_size(int): Data batch size.
|
||||
[in] shuffle(bool): Shuffle data if true.
|
||||
[in] transform(jittor.transform): transform data.
|
||||
|
||||
Example::
|
||||
|
||||
from jittor.dataset.mnist import EMNIST
|
||||
train_loader = EMNIST(train=True).set_attrs(batch_size=16, shuffle=True)
|
||||
for i, (imgs, target) in enumerate(train_loader):
|
||||
...
|
||||
'''
|
||||
|
||||
_merged_classes = {'c', 'i', 'j', 'k', 'l', 'm', 'o', 'p', 's', 'u', 'v', 'w', 'x', 'y', 'z'}
|
||||
_all_classes = set(string.digits + string.ascii_letters)
|
||||
classes_split_dict = {
|
||||
'byclass': sorted(list(_all_classes)),
|
||||
'bymerge': sorted(list(_all_classes - _merged_classes)),
|
||||
'balanced': sorted(list(_all_classes - _merged_classes)),
|
||||
'letters': ['N/A'] + list(string.ascii_lowercase),
|
||||
'digits': list(string.digits),
|
||||
'mnist': list(string.digits),
|
||||
}
|
||||
|
||||
def __init__(self, data_root=dataset_root+"/emnist_data/",
|
||||
split='byclass',
|
||||
train=True,
|
||||
download=True,
|
||||
batch_size = 16,
|
||||
shuffle = False,
|
||||
transform=None):
|
||||
# if you want to test resnet etc., you should set input_channel = 3, because the net expects 3 input channels
|
||||
super().__init__()
|
||||
self.data_root = data_root
|
||||
self.is_train = train
|
||||
self.transform = transform
|
||||
self.batch_size = batch_size
|
||||
self.shuffle = shuffle
|
||||
if download == True:
|
||||
self.download_url()
|
||||
data_root = os.path.join(data_root, "gzip")
|
||||
|
||||
filesname = [
|
||||
f"emnist-{split}-train-images-idx3-ubyte.gz",
|
||||
f"emnist-{split}-t10k-images-idx3-ubyte.gz",
|
||||
f"emnist-{split}-train-labels-idx1-ubyte.gz",
|
||||
f"emnist-{split}-t10k-labels-idx1-ubyte.gz"
|
||||
]
|
||||
for i in range(4):
|
||||
filesname[i] = os.path.join(data_root, filesname[i])
|
||||
self.mnist = {}
|
||||
if self.is_train:
|
||||
with gzip.open(filesname[0], 'rb') as f:
|
||||
self.mnist["images"] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28, 28).transpose(0,2,1)
|
||||
with gzip.open(filesname[2], 'rb') as f:
|
||||
self.mnist["labels"] = np.frombuffer(f.read(), np.uint8, offset=8)
|
||||
else:
|
||||
with gzip.open(filesname[1], 'rb') as f:
|
||||
self.mnist["images"] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28, 28).transpose(0,2,1)
|
||||
with gzip.open(filesname[3], 'rb') as f:
|
||||
self.mnist["labels"] = np.frombuffer(f.read(), np.uint8, offset=8)
|
||||
assert(self.mnist["images"].shape[0] == self.mnist["labels"].shape[0])
|
||||
self.total_len = self.mnist["images"].shape[0]
|
||||
# this function must be called
|
||||
self.set_attrs(total_len = self.total_len)
|
||||
|
||||
def __getitem__(self, index):
|
||||
img = Image.fromarray(self.mnist['images'][index]).convert('RGB')
|
||||
if self.transform:
|
||||
img = self.transform(img)
|
||||
return trans.to_tensor(img), self.mnist['labels'][index]
|
||||
|
||||
def download_url(self):
|
||||
'''
|
||||
Download the EMNIST dataset; this function is called when download is True.
|
||||
'''
|
||||
resources = [
|
||||
("https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip", "58c8d27c78d21e728a6bc7b3cc06412e"),
|
||||
]
|
||||
|
||||
for url, md5 in resources:
|
||||
filename = "emnist.zip"
|
||||
download_url_to_local(url, filename, self.data_root, md5)
|
||||
import zipfile
|
||||
zf = zipfile.ZipFile(os.path.join(self.data_root, filename))
|
||||
try:
|
||||
zf.extractall(path=self.data_root)
|
||||
except RuntimeError as e:
|
||||
print(e)
|
||||
raise
|
||||
zf.close()
|
||||
|
||||
|
|
|
@ -0,0 +1,125 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <algorithm>
|
||||
#include "var.h"
|
||||
#include "cub_cumsum_op.h"
|
||||
#include <vector>
|
||||
#include "executor.h"
|
||||
#include "ops/op_register.h"
|
||||
#ifdef JIT_cuda
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/block/block_scan.cuh>
|
||||
#include <thrust/iterator/reverse_iterator.h>
|
||||
#endif
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
static auto make_cub_cumsum = get_op_info("cub_cumsum")
|
||||
.get_constructor<VarPtr, Var*, bool>();
|
||||
|
||||
CubCumsumOp::CubCumsumOp(Var* x, bool reverse) : x(x),reverse(reverse) {
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
y = create_output(nullptr, x->dtype());
|
||||
}
|
||||
|
||||
void CubCumsumOp::infer_shape() {
|
||||
ASSERT(x->shape.size() == 1 || x->shape.size() == 2); //TODO:support batch_cumsum
|
||||
y->set_shape(x->shape);
|
||||
}
|
||||
|
||||
void CubCumsumOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][reverse:") << reverse;
|
||||
jk << _CS("]");
|
||||
}
|
||||
|
||||
VarPtr CubCumsumOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
return make_cub_cumsum(dout, !reverse);
|
||||
// return ArgsortOp::get_grad(out, dout, v, v_index, v->shape.size()-1, y);
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
|
||||
#define ITEMS_PER_THREAD 4
|
||||
#define BLOCK_THREADS 1024
|
||||
|
||||
__global__ void BlockScanKernel(Tx* __restrict__ xp, Ty* __restrict__ yp, int batch_num, int num_items) {
|
||||
typedef cub::BlockScan<Tx, BLOCK_THREADS> BlockScanT;
|
||||
__shared__ typename BlockScanT::TempStorage temp_storage;
|
||||
|
||||
int batch_id = blockIdx.x;
|
||||
int offset = threadIdx.x * ITEMS_PER_THREAD;
|
||||
for (int block_offset = offset; block_offset < num_items; block_offset += BLOCK_THREADS * ITEMS_PER_THREAD) {
|
||||
int items = ITEMS_PER_THREAD;
|
||||
if (block_offset + ITEMS_PER_THREAD > num_items) {
|
||||
items = num_items - block_offset;
|
||||
}
|
||||
Tx thread_data[ITEMS_PER_THREAD];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
|
||||
if (i<items)
|
||||
#if reverse
|
||||
thread_data[i] = xp[batch_id * num_items + (num_items - 1 - (block_offset + i))];
|
||||
#else
|
||||
thread_data[i] = xp[batch_id * num_items + block_offset + i];
|
||||
#endif
|
||||
}
|
||||
BlockScanT(temp_storage).InclusiveSum(thread_data, thread_data);
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
|
||||
if (i<items)
|
||||
#if reverse
|
||||
yp[batch_id * num_items + (num_items - 1 - (block_offset + i))] = thread_data[i];
|
||||
#else
|
||||
yp[batch_id * num_items + block_offset + i] = thread_data[i];
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CubCumsumOp::jit_run() {
|
||||
auto* __restrict__ xp = x->ptr<Tx>();
|
||||
auto* __restrict__ yp = y->ptr<Ty>();
|
||||
if (x->shape.size() == 1){
|
||||
int num_items = x->shape[0];
|
||||
|
||||
// Determine temporary device storage requirements for inclusive prefix sum
|
||||
void *d_temp_storage = NULL;
|
||||
size_t temp_storage_bytes = 0, temp_storage_allocation;
|
||||
cub::DeviceScan::InclusiveSum(NULL, temp_storage_bytes, xp, yp, num_items);
|
||||
d_temp_storage = exe.temp_allocator->alloc(temp_storage_bytes, temp_storage_allocation);
|
||||
// Allocate temporary storage for inclusive prefix sum
|
||||
// cudaMalloc(&d_temp_storage, temp_storage_bytes);
|
||||
// Run inclusive prefix sum
|
||||
if (reverse) {
|
||||
auto xp_ = thrust::make_reverse_iterator(xp + num_items);
|
||||
auto yp_ = thrust::make_reverse_iterator(yp + num_items);
|
||||
cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, xp_, yp_, num_items);
|
||||
} else {
|
||||
cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, xp, yp, num_items);
|
||||
}
|
||||
// yp <-- [8, 14, 21, 26, 29, 29, 38]
|
||||
exe.temp_allocator->free(d_temp_storage, temp_storage_bytes, temp_storage_allocation);
|
||||
} else {
|
||||
int batch_num = x->shape[0];
|
||||
int num_items = x->shape[1];
|
||||
BlockScanKernel<<<batch_num, BLOCK_THREADS>>>(xp, yp, batch_num, num_items);
|
||||
}
|
||||
}
|
||||
#endif // JIT_cuda
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,28 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct CubCumsumOp : Op {
|
||||
Var* x, * y;
|
||||
bool reverse;
|
||||
|
||||
CubCumsumOp(Var* x, bool reverse=false);
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
|
||||
void infer_shape() override;
|
||||
const char* name() const override { return "cub_cumsum"; }
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -20,6 +20,8 @@ namespace jittor {
|
|||
|
||||
CublasMatmulOp::CublasMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b)
|
||||
: a(a), b(b), trans_a(trans_a), trans_b(trans_b) {
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
// TODO: support int8 * int8
|
||||
ASSERT(a->dtype().is_float() && b->dtype().is_float()) << "type of two inputs should be the same";
|
||||
// TODO: support different input types
|
||||
|
@ -51,7 +53,6 @@ void CublasMatmulOp::jit_prepare(JK& jk) {
|
|||
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cpu
|
||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||
void CublasMatmulOp::jit_run() {
|
||||
cublasHandle_t& handle_ = cublas_handle;
|
||||
|
@ -78,7 +79,6 @@ void CublasMatmulOp::jit_run() {
|
|||
a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
|
||||
c->ptr<T>(), k));
|
||||
}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
||||
|
|
|
@ -19,11 +19,12 @@ struct cublas_initer {
|
|||
inline cublas_initer() {
|
||||
if (!get_device_count()) return;
|
||||
checkCudaErrors(cublasCreate(&cublas_handle));
|
||||
LOGv << "cublasCreate finished";
|
||||
LOGv << "cublasCreate finished" << (void*)cublas_handle;
|
||||
}
|
||||
|
||||
inline ~cublas_initer() {
|
||||
if (!get_device_count()) return;
|
||||
LOGv << "cublasDestroy:" << (void*)cublas_handle;
|
||||
checkCudaErrors(cublasDestroy(cublas_handle));
|
||||
LOGv << "cublasDestroy finished";
|
||||
}
|
||||
|
|
|
@ -100,15 +100,15 @@ def repeat(x, *shape):
|
|||
x_shape = (len_shape - len_x_shape) * [1] + x.shape
|
||||
x = x.broadcast(x_shape)
|
||||
elif len_x_shape > len_shape:
|
||||
rep_shape = (len_x_shape - len_shape) * [1] + shape
|
||||
#TODO if input.shape[i]=1, no add [1]
|
||||
rep_shape = (len_x_shape - len_shape) * [1] + list(shape)
|
||||
|
||||
reshape_shape = []
|
||||
broadcast_shape = []
|
||||
for x_s,r_s in zip(x_shape,rep_shape):
|
||||
reshape_shape.append(1)
|
||||
if r_s != 1:
|
||||
reshape_shape.append(1)
|
||||
broadcast_shape.append(r_s)
|
||||
reshape_shape.append(x_s)
|
||||
|
||||
broadcast_shape.append(r_s)
|
||||
broadcast_shape.append(1)
|
||||
|
||||
x = x.reshape(reshape_shape)
|
||||
|
@ -344,7 +344,7 @@ def cross(input, other, dim=-1):
|
|||
return jt.contrib.concat([a1.unsqueeze(dim),a2.unsqueeze(dim),a3.unsqueeze(dim)], dim=dim)
|
||||
jt.Var.cross = cross
|
||||
|
||||
def normalize(input, p=2, dim=1, eps=1e-12):
|
||||
def normalize(input, p=2, dim=1, eps=1e-30):
|
||||
r'''
|
||||
Performs L_p normalization of inputs over specified dimension.
|
||||
|
||||
|
@ -376,9 +376,7 @@ def normalize(input, p=2, dim=1, eps=1e-12):
|
|||
[0.02647221 0.59484214 0.80340654]
|
||||
[0.6910677 0.58067477 0.4303977 ]]
|
||||
'''
|
||||
assert p == 2
|
||||
if p == 2:
|
||||
return input / jt.maximum(input.sqr().sum(dim,True).sqrt(), eps)
|
||||
return input / input.norm(p, dim, True, eps)
|
||||
jt.Var.normalize = normalize
|
||||
|
||||
def unbind(x, dim=0):
|
||||
|
@ -661,32 +659,66 @@ def _prod(x,dim=0):
|
|||
return jt.exp(x)
|
||||
|
||||
|
||||
def cumsum_forward(np, data):
|
||||
a = data['inputs'][0]
|
||||
b = data['outputs'][0]
|
||||
np.cumsum(a, axis=1, out=b)
|
||||
def numpy_cumsum(x, dim=None):
|
||||
def cumsum_forward(np, data):
|
||||
dim = data['inputs'][1].item()
|
||||
a = data['inputs'][0]
|
||||
b = data['outputs'][0]
|
||||
np.cumsum(a, axis=dim, out=b)
|
||||
|
||||
def cumsum_backward(np, data):
|
||||
dout = data['dout']
|
||||
out = data['outputs'][0]
|
||||
np.cumsum(dout[:, ::-1], axis=1, out=out)
|
||||
np.copyto(out, out[:, ::-1])
|
||||
def cumsum_backward(np, data):
|
||||
dim = data['inputs'][1].item()
|
||||
dout = data['dout']
|
||||
out = data['outputs'][0]
|
||||
np.cumsum(dout[..., ::-1], axis=dim, out=out)
|
||||
np.copyto(out, out[..., ::-1])
|
||||
if (dim == None):
|
||||
dim = -1
|
||||
assert(dim >= -1 and dim < len(x.shape))
|
||||
dim_var = jt.array([dim],dtype=int)
|
||||
return jt.numpy_code(x.shape, x.dtype, [x, dim_var.detach()], cumsum_forward, [cumsum_backward])
|
||||
|
||||
def cub_cumsum(x, dim=None):
|
||||
if (dim == None):
|
||||
dim = -1
|
||||
assert(dim >= -1 and dim < len(x.shape))
|
||||
shape = x.shape
|
||||
if (dim != -1 and dim != len(shape) - 1):
|
||||
order = range(len(shape))
|
||||
order[dim], order[-1] = order[-1], order[dim]
|
||||
shape[dim], shape[-1] = shape[-1], shape[dim]
|
||||
x = x.permute(order)
|
||||
if (len(shape) > 2):
|
||||
x = x.reshape([-1, shape[-1]])
|
||||
x = jt.compile_extern.cub_ops.cub_cumsum(x)
|
||||
if (len(shape) > 2):
|
||||
x = x.reshape(shape)
|
||||
if (dim != -1 and dim != len(shape) - 1):
|
||||
x = x.permute(order)
|
||||
return x
|
||||
|
||||
def cumsum(x, dim=None):
|
||||
'''
|
||||
Parameters:
|
||||
-----------
|
||||
x: [batch_size, N], jt.var
|
||||
x: jt.var
|
||||
dim: int
|
||||
|
||||
Returns:
|
||||
--------
|
||||
the cumulative sum of x
|
||||
the cumulative sum in dim of x
|
||||
'''
|
||||
return jt.numpy_code(x.shape, x.dtype, [x], cumsum_forward, [cumsum_backward])
|
||||
if (dim == None):
|
||||
dim = -1
|
||||
assert(dim >= -1 and dim < len(x.shape))
|
||||
if jt.has_cuda:
|
||||
return cub_cumsum(x, dim)
|
||||
else:
|
||||
return numpy_cumsum(x, dim)
|
||||
|
||||
jt.Var.cumsum = cumsum
|
||||
|
||||
def cumprod(x,dim=0):
|
||||
def cumprod(x,dim=None):
|
||||
x = jt.log(x)
|
||||
x = cumsum(x,dim=dim)
|
||||
return jt.exp(x)
|
||||
|
@ -749,26 +781,26 @@ def triu_(x,diagonal=0):
|
|||
jt.Var.triu_ = triu_
|
||||
|
||||
def print_tree(now, max_memory_size, prefix1, prefix2, build_by):
|
||||
def format_size(s):
|
||||
def format_size(s, end='B'):
|
||||
if (s < 1024):
|
||||
s = str(s)
|
||||
return s + ' B'
|
||||
return s + ' '+end
|
||||
|
||||
if (s < 1024*1024):
|
||||
s = format(s/1024, '.2f')
|
||||
return s + ' KB'
|
||||
return s + ' K'+end
|
||||
|
||||
if (s < 1024*1024*1024):
|
||||
s = format(s/1024/1024, '.2f')
|
||||
return s + ' MB'
|
||||
return s + ' M'+end
|
||||
|
||||
s = format(s/1024/1024/1024, '.2f')
|
||||
return s + ' GB'
|
||||
return s + ' G'+end
|
||||
|
||||
out = ''
|
||||
tab = ' '
|
||||
out += prefix1+now['name']+'('+now['type']+')\n'
|
||||
out += prefix2+'['+format_size(now['size'])+'; '+format(now['size']/max_memory_size*100, '.2f')+'%]\n'
|
||||
out += prefix2+'['+format_size(now['size'])+'; '+format(now['size']/max_memory_size*100, '.2f')+'%; cnt:'+format_size(now['cnt'],'') + ']\n'
|
||||
if (build_by == 0):
|
||||
for p in now['path']:
|
||||
out += prefix2+p+'\n'
|
||||
|
@ -834,7 +866,7 @@ Output::
|
|||
vars_ = vars_[1:]
|
||||
for v_ in vars_:
|
||||
v__ = v_.split(div2)
|
||||
var = {'size':int(v__[1]), 'stack':[]}
|
||||
var = {'size':int(v__[1]), 'stack':[], 'cnt':1}
|
||||
v__ = v__[2:-1]
|
||||
for s_ in v__:
|
||||
s__ = s_.split(div3)
|
||||
|
@ -842,7 +874,7 @@ Output::
|
|||
var['stack'].append(s)
|
||||
vars.append(var)
|
||||
if (build_by == 0): # build tree by name
|
||||
tree = {'name':'root', "children":[], 'size':0, 'path':[], 'type':''}
|
||||
tree = {'name':'root', "children":[], 'size':0, 'cnt':1, 'path':[], 'type':''}
|
||||
|
||||
def find_child(now, key):
|
||||
for c in now['children']:
|
||||
|
@ -852,6 +884,7 @@ Output::
|
|||
for v in vars:
|
||||
now = tree
|
||||
now['size'] += v['size']
|
||||
now['cnt'] += v['cnt']
|
||||
for s in v['stack']:
|
||||
ch = find_child(now, s['name'])
|
||||
if (ch is not None):
|
||||
|
@ -860,12 +893,13 @@ Output::
|
|||
assert(ch['type']==s['type'])
|
||||
now = ch
|
||||
now['size'] += v['size']
|
||||
now['cnt'] += v['cnt']
|
||||
else:
|
||||
now_ = {'name':s['name'], "children":[], 'size':v['size'], 'path':[s['path']], 'type':s['type']}
|
||||
now_ = {'name':s['name'], "children":[], 'size':v['size'], 'cnt':v['cnt'], 'path':[s['path']], 'type':s['type']}
|
||||
now['children'].append(now_)
|
||||
now = now_
|
||||
elif (build_by == 1): # build tree by path
|
||||
tree = {'name':'root', "children":[], 'size':0, 'path':'_root_', 'type':''}
|
||||
tree = {'name':'root', "children":[], 'size':0, 'cnt':0, 'path':'_root_', 'type':''}
|
||||
|
||||
def find_child(now, key):
|
||||
for c in now['children']:
|
||||
|
@ -875,13 +909,15 @@ Output::
|
|||
for v in vars:
|
||||
now = tree
|
||||
now['size'] += v['size']
|
||||
now['cnt'] += v['cnt']
|
||||
for s in v['stack']:
|
||||
ch = find_child(now, s['path'])
|
||||
if (ch is not None):
|
||||
now = ch
|
||||
now['size'] += v['size']
|
||||
now['cnt'] += v['cnt']
|
||||
else:
|
||||
now_ = {'name':s['name'], "children":[], 'size':v['size'], 'path':s['path'], 'type':s['type']}
|
||||
now_ = {'name':s['name'], "children":[], 'size':v['size'], 'cnt':v['cnt'], 'path':s['path'], 'type':s['type']}
|
||||
now['children'].append(now_)
|
||||
now = now_
|
||||
else:
|
||||
|
@ -1063,14 +1099,17 @@ def randperm(n, dtype="int32"):
|
|||
index, _ = jt.argsort(key)
|
||||
return index.cast(dtype)
|
||||
|
||||
def set_global_seed(seed):
|
||||
def set_global_seed(seed, different_seed_for_mpi=True):
|
||||
''' Sets the seeds of the random number generators of Python, numpy and jittor,
|
||||
simultaneously.
|
||||
|
||||
.. note::
|
||||
Jittor also guarantees that each worker of jittor.dataset.Dataset holds a different seed,
|
||||
which is global_seed ^ worker_id ^ 1234.
|
||||
It also guarantees that each process holds a different seed when using MPI,
|
||||
which is (global_seed ^ (worker_id*1167)) ^ 1234 + jt.rank * 2591
|
||||
'''
|
||||
if (different_seed_for_mpi):
|
||||
seed = seed + jt.rank * 2591
|
||||
import random
|
||||
random.seed(seed)
|
||||
jt.set_seed(seed)
|
||||
|
@ -1081,6 +1120,9 @@ def set_global_seed(seed):
|
|||
except:
|
||||
pass
|
||||
|
||||
import time
|
||||
set_global_seed(int(time.time() * 1000000) % 100000007)
|
||||
|
||||
def searchsorted(sorted, values, right=False):
|
||||
"""
|
||||
Find the indices from the innermost dimension of `sorted` for each `values`.
|
||||
|
@ -1269,3 +1311,348 @@ jt.Var.roll = roll
|
|||
def safe_log(x):
|
||||
return jt.safe_clip(x, 1e-30, 1e30).log()
|
||||
jt.Var.safe_log = safe_log
|
||||
|
||||
class _CTCLossFunction(jt.Function):
|
||||
def execute(self, log_probs, targets, input_lengths, target_lengths, blank=0, zero_infinity=False):
|
||||
self.blank = blank
|
||||
T, N, C = log_probs.shape
|
||||
_N, S = targets.shape
|
||||
assert _N == N
|
||||
log_alpha = jt.full([T,N,S*2+1], -1e30)
|
||||
result = jt.empty((N,))
|
||||
jt.code([log_probs, targets, input_lengths, target_lengths], [log_alpha, result], cpu_src=f"""
|
||||
constexpr int blank = {blank};
|
||||
for (int i=0; i<in0_shape1; i++) {{
|
||||
int input_len = @in2(i);
|
||||
int target_len = @in3(i);
|
||||
@out0(0,i,0) = @in0(0,i,blank);
|
||||
if (target_len)
|
||||
@out0(0,i,1) = @in0(0,i,@in1(i,0));
|
||||
for (int j=1; j<input_len; j++)
|
||||
for (int k=0; k<target_len*2+1; k++) {{
|
||||
int target = k%2 ? @in1(i,k/2) : blank;
|
||||
int target_2 = target;
|
||||
if (k>1 && k%2) target_2 = @in1(i,k/2-1);
|
||||
out_type l1 = @out0(j-1,i,k);
|
||||
out_type l2 = -1e30;
|
||||
if (k>0) l2 = @out0(j-1,i,k-1);
|
||||
out_type l3 = -1e30;
|
||||
if (k>1 && target_2 != target)
|
||||
l3 = @out0(j-1,i,k-2);
|
||||
out_type m = std::max(l1, std::max(l2, l3));
|
||||
@out0(j,i,k) = std::log(
|
||||
std::exp(l1-m) +
|
||||
std::exp(l2-m) +
|
||||
std::exp(l3-m)
|
||||
) + m + @in0(j,i,target);
|
||||
}}
|
||||
if (input_len==0)
|
||||
@out1(i) = @out0(0,i,0);
|
||||
else {{
|
||||
out_type l1 = @out0(input_len-1, i, target_len*2);
|
||||
out_type l2 = -1e30;
|
||||
if (target_len)
|
||||
l2 = @out0(input_len-1, i, target_len*2-1);
|
||||
out_type m = std::max(l1, l2);
|
||||
out_type log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
|
||||
@out1(i) = -log_likelihood;
|
||||
}}
|
||||
}}
|
||||
""", cuda_src=f"""
|
||||
__global__ void kernel(@ARGS_DEF) {{
|
||||
@PRECALC;
|
||||
constexpr int blank = {blank};
|
||||
for (int i=blockIdx.x; i<in0_shape1; i+=gridDim.x) {{
|
||||
int input_len = @in2(i);
|
||||
int target_len = @in3(i);
|
||||
@out0(0,i,0) = @in0(0,i,blank);
|
||||
if (target_len)
|
||||
@out0(0,i,1) = @in0(0,i,@in1(i,0));
|
||||
for (int j=1; j<input_len; j++)
|
||||
for (int k=threadIdx.x; k-threadIdx.x<target_len*2+1; k+=blockDim.x) {{
|
||||
__syncthreads();
|
||||
if (k>=target_len*2+1)
|
||||
continue;
|
||||
int target = k%2 ? @in1(i,k/2) : blank;
|
||||
int target_2 = target;
|
||||
if (k>1 && k%2) target_2 = @in1(i,k/2-1);
|
||||
out_type l1 = @out0(j-1,i,k);
|
||||
out_type l2 = -1e30;
|
||||
if (k>0) l2 = @out0(j-1,i,k-1);
|
||||
out_type l3 = -1e30;
|
||||
if (k>1 && target_2 != target)
|
||||
l3 = @out0(j-1,i,k-2);
|
||||
out_type m = ::max(l1, ::max(l2, l3));
|
||||
@out0(j,i,k) = ::log(
|
||||
::exp(l1-m) +
|
||||
::exp(l2-m) +
|
||||
::exp(l3-m)
|
||||
) + m + @in0(j,i,target);
|
||||
}}
|
||||
__syncthreads();
|
||||
if (input_len==0)
|
||||
@out1(i) = @out0(0,i,0);
|
||||
else {{
|
||||
out_type l1 = @out0(input_len-1, i, target_len*2);
|
||||
out_type l2 = -1e30;
|
||||
if (target_len)
|
||||
l2 = @out0(input_len-1, i, target_len*2-1);
|
||||
out_type m = ::max(l1, l2);
|
||||
out_type log_likelihood = ::log(::exp(l1-m)+::exp(l2-m))+m;
|
||||
@out1(i) = -log_likelihood;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
kernel<<<std::min(in0_shape1, 1024), std::min(in1_shape1*2+1, 1024)>>>(@ARGS);
|
||||
""")
|
||||
self.saved_var = [log_probs, targets, input_lengths, target_lengths, log_alpha, result]
|
||||
return result
|
||||
|
||||
def grad(self, dout):
|
||||
blank = self.blank
|
||||
inputs = self.saved_var + [dout]
|
||||
dlog_probs = jt.zeros_like(inputs[0])
|
||||
dlog_alpha = jt.zeros_like(inputs[4])
|
||||
jt.code(inputs, [dlog_probs, dlog_alpha], cpu_src=f"""
|
||||
constexpr int blank = {blank};
|
||||
for (int i=0; i<in0_shape1; i++) {{
|
||||
int input_len = @in2(i);
|
||||
int target_len = @in3(i);
|
||||
if (input_len==0)
|
||||
// write out1 --> read in6
|
||||
// out1(i) = out0(0,i,0);
|
||||
@out1(0,i,0) = @in6(i);
|
||||
else {{
|
||||
out_type l1 = @in4(input_len-1, i, target_len*2);
|
||||
out_type l2 = -1e30;
|
||||
if (target_len)
|
||||
l2 = @in4(input_len-1, i, target_len*2-1);
|
||||
out_type m = std::max(l1, l2);
|
||||
// out_type log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
|
||||
// out1(i) = -log_likelihood;
|
||||
out_type l1_exp = std::exp(l1-m);
|
||||
out_type l2_exp = std::exp(l2-m);
|
||||
out_type sumexp = l1_exp + l2_exp;
|
||||
|
||||
out_type dlog_likelihood = -@in6(i);
|
||||
out_type dl1 = dlog_likelihood * l1_exp / sumexp;
|
||||
out_type dl2 = dlog_likelihood * l2_exp / sumexp;
|
||||
|
||||
@out1(input_len-1, i, target_len*2) = dl1;
|
||||
if (target_len)
|
||||
@out1(input_len-1, i, target_len*2-1) = dl2;
|
||||
}}
|
||||
for (int j=input_len-1; j>0; j--)
|
||||
for (int k=0; k<target_len*2+1; k++) {{
|
||||
int target = k%2 ? @in1(i,k/2) : blank;
|
||||
int target_2 = target;
|
||||
if (k>1 && k%2) target_2 = @in1(i,k/2-1);
|
||||
out_type l1 = @in4(j-1,i,k);
|
||||
out_type l2 = -1e30;
|
||||
if (k>0) l2 = @in4(j-1,i,k-1);
|
||||
out_type l3 = -1e30;
|
||||
if (k>1 && target_2 != target)
|
||||
l3 = @in4(j-1,i,k-2);
|
||||
out_type m = std::max(l1, std::max(l2, l3));
|
||||
out_type l1_exp = std::exp(l1-m);
|
||||
out_type l2_exp = std::exp(l2-m);
|
||||
out_type l3_exp = std::exp(l3-m);
|
||||
out_type sumexp = l1_exp + l2_exp + l3_exp;
|
||||
out_type dalpha = @out1(j,i,k);
|
||||
|
||||
@out0(j,i,target) += dalpha;
|
||||
|
||||
@out1(j-1,i,k) += dalpha * l1_exp / sumexp;
|
||||
if (k>0)
|
||||
@out1(j-1,i,k-1) += dalpha * l2_exp / sumexp;
|
||||
if (k>1 && target_2 != target)
|
||||
@out1(j-1,i,k-2) += dalpha * l3_exp / sumexp;
|
||||
}}
|
||||
// read in0 -> write out0
|
||||
// write out0 ->read out1
|
||||
// out0(0,i,0) = in0(0,i,blank);
|
||||
@out0(0,i,blank) += @out1(0,i,0);
|
||||
if (target_len)
|
||||
@out0(0,i,@in1(i,0)) += @out1(0,i,1);
|
||||
}}
|
||||
""", cuda_src=f"""
|
||||
__global__ void kernel(@ARGS_DEF) {{
|
||||
@PRECALC;
|
||||
constexpr int blank = {blank};
|
||||
for (int i=blockIdx.x; i<in0_shape1; i+=gridDim.x) {{
|
||||
int input_len = @in2(i);
|
||||
int target_len = @in3(i);
|
||||
if (input_len==0)
|
||||
// write out1 --> read in6
|
||||
// out1(i) = out0(0,i,0);
|
||||
@out1(0,i,0) = @in6(i);
|
||||
else {{
|
||||
out_type l1 = @in4(input_len-1, i, target_len*2);
|
||||
out_type l2 = -1e30;
|
||||
if (target_len)
|
||||
l2 = @in4(input_len-1, i, target_len*2-1);
|
||||
out_type m = ::max(l1, l2);
|
||||
// out_type log_likelihood = ::log(::exp(l1-m)+::exp(l2-m))+m;
|
||||
// out1(i) = -log_likelihood;
|
||||
out_type l1_exp = ::exp(l1-m);
|
||||
out_type l2_exp = ::exp(l2-m);
|
||||
out_type sumexp = l1_exp + l2_exp;
|
||||
|
||||
out_type dlog_likelihood = -@in6(i);
|
||||
out_type dl1 = dlog_likelihood * l1_exp / sumexp;
|
||||
out_type dl2 = dlog_likelihood * l2_exp / sumexp;
|
||||
|
||||
@out1(input_len-1, i, target_len*2) = dl1;
|
||||
if (target_len)
|
||||
@out1(input_len-1, i, target_len*2-1) = dl2;
|
||||
}}
|
||||
for (int j=input_len-1; j>0; j--)
|
||||
for (int k=threadIdx.x; k-threadIdx.x<target_len*2+1; k+=blockDim.x) {{
|
||||
__syncthreads();
|
||||
if (k>=target_len*2+1)
|
||||
continue;
|
||||
int target = k%2 ? @in1(i,k/2) : blank;
|
||||
int target_2 = target;
|
||||
if (k>1 && k%2) target_2 = @in1(i,k/2-1);
|
||||
out_type l1 = @in4(j-1,i,k);
|
||||
out_type l2 = -1e30;
|
||||
if (k>0) l2 = @in4(j-1,i,k-1);
|
||||
out_type l3 = -1e30;
|
||||
if (k>1 && target_2 != target)
|
||||
l3 = @in4(j-1,i,k-2);
|
||||
out_type m = ::max(l1, ::max(l2, l3));
|
||||
out_type l1_exp = ::exp(l1-m);
|
||||
out_type l2_exp = ::exp(l2-m);
|
||||
out_type l3_exp = ::exp(l3-m);
|
||||
out_type sumexp = l1_exp + l2_exp + l3_exp;
|
||||
out_type dalpha = @out1(j,i,k);
|
||||
|
||||
atomicAdd(&@out0(j,i,target), dalpha);
|
||||
|
||||
atomicAdd(&@out1(j-1,i,k), dalpha * l1_exp / sumexp);
|
||||
if (k>0)
|
||||
atomicAdd(&@out1(j-1,i,k-1), dalpha * l2_exp / sumexp);
|
||||
if (k>1 && target_2 != target)
|
||||
atomicAdd(&@out1(j-1,i,k-2), dalpha * l3_exp / sumexp);
|
||||
}}
|
||||
// read in0 -> write out0
|
||||
// write out0 ->read out1
|
||||
// out0(0,i,0) = in0(0,i,blank);
|
||||
__syncthreads();
|
||||
if (threadIdx.x==0) {{
|
||||
@out0(0,i,blank) += @out1(0,i,0);
|
||||
if (target_len)
|
||||
@out0(0,i,@in1(i,0)) += @out1(0,i,1);
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
kernel<<<std::min(in0_shape1, 1024), std::min(in1_shape1*2+1, 1024)>>>(@ARGS);
|
||||
""")
|
||||
return (dlog_probs,)
|
||||
|
||||
|
||||
def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False):
|
||||
'''The Connectionist Temporal Classification loss.
|
||||
|
||||
|
||||
Reference:
|
||||
A. Graves et al.: Connectionist Temporal Classification:
|
||||
Labelling Unsegmented Sequence Data with Recurrent Neural Networks:
|
||||
https://www.cs.toronto.edu/~graves/icml_2006.pdf
|
||||
|
||||
Input:
|
||||
|
||||
log_probs: shape is [T, N, C], T is the sequence length, N is the batch size, C is the class number.
|
||||
targets: shape is [N, S], N is the batch size, S is the target sequence length, element should between [0,C).
|
||||
input_lengths: shape is [N], which represents the length of input, element should between [0,T].
|
||||
target_lengths: shape is N, which represents the length of target, element should between [0,S].
|
||||
blank (int, default 0): blank label index
|
||||
reduction (string): reduce batch loss,
|
||||
if reduction is none, it will return (N,) array,
|
||||
if reduction is mean or sum, it will return one scalar
|
||||
zero_infinity (bool, default False):
|
||||
zero_infinity for grad
|
||||
|
||||
Example:
|
||||
|
||||
import jittor as jt
|
||||
T = 50 # Input sequence length
|
||||
C = 20 # Number of classes (including blank)
|
||||
N = 16 # Batch size
|
||||
S = 30 # Target sequence length of longest target in batch (padding length)
|
||||
S_min = 10 # Minimum target length, for demonstration purposes
|
||||
|
||||
input = jt.randn(T, N, C).log_softmax(2)
|
||||
# Initialize random batch of targets (0 = blank, 1:C = classes)
|
||||
target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int)
|
||||
|
||||
input_lengths = jt.full((N,), T, dtype=jt.int)
|
||||
target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int)
|
||||
loss = jt.ctc_loss(input, target, input_lengths, target_lengths)
|
||||
|
||||
dinput = jt.grad(loss, input)
|
||||
|
||||
'''
|
||||
result = _CTCLossFunction.apply(log_probs, targets, input_lengths, target_lengths, blank, zero_infinity)
|
||||
if reduction=="mean":
|
||||
return result.mean()
|
||||
elif reduction=="sum":
|
||||
return result.sum()
|
||||
assert reduction=="none"
|
||||
return result
|
||||
|
||||
|
||||
class CTCLoss(jt.Module):
|
||||
'''The Connectionist Temporal Classification loss.
|
||||
|
||||
|
||||
Reference:
|
||||
A. Graves et al.: Connectionist Temporal Classification:
|
||||
Labelling Unsegmented Sequence Data with Recurrent Neural Networks:
|
||||
https://www.cs.toronto.edu/~graves/icml_2006.pdf
|
||||
|
||||
|
||||
Args:
|
||||
|
||||
blank (int, default 0): blank label index
|
||||
reduction (string): reduce batch loss,
|
||||
if reduction is none, it will return (N,) array,
|
||||
if reduction is mean or sum, it will return one scalar
|
||||
zero_infinity (bool, default False):
|
||||
zero_infinity for grad
|
||||
|
||||
Input:
|
||||
|
||||
log_probs: shape is [T, N, C], T is the sequence length, N is the batch size, C is the class number.
|
||||
targets: shape is [N, S], N is the batch size, S is the target sequence length, element should between [0,C).
|
||||
input_lengths: shape is [N], which represents the length of input, element should between [0,T].
|
||||
target_lengths: shape is N, which represents the length of target, element should between [0,S].
|
||||
|
||||
Example:
|
||||
|
||||
import jittor as jt
|
||||
T = 50 # Input sequence length
|
||||
C = 20 # Number of classes (including blank)
|
||||
N = 16 # Batch size
|
||||
S = 30 # Target sequence length of longest target in batch (padding length)
|
||||
S_min = 10 # Minimum target length, for demonstration purposes
|
||||
|
||||
input = jt.randn(T, N, C).log_softmax(2)
|
||||
# Initialize random batch of targets (0 = blank, 1:C = classes)
|
||||
target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int)
|
||||
|
||||
input_lengths = jt.full((N,), T, dtype=jt.int)
|
||||
target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int)
|
||||
ctc_loss = jt.CTCLoss()
|
||||
loss = ctc_loss(input, target, input_lengths, target_lengths)
|
||||
|
||||
dinput = jt.grad(loss, input)
|
||||
|
||||
'''
|
||||
def __init__(self, blank=0, reduction='mean', zero_infinity=False):
|
||||
self.blank = blank
|
||||
self.reduction = reduction
|
||||
self.zero_infinity = zero_infinity
|
||||
|
||||
def execute(self, log_probs, targets, input_lengths, target_lengths):
|
||||
return ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction, self.zero_infinity)
|
|
@ -7,12 +7,12 @@ import math
|
|||
|
||||
|
||||
model_urls = {
|
||||
'res2net50_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_4s-06e79181.pth',
|
||||
'res2net50_48w_2s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_48w_2s-afed724a.pth',
|
||||
'res2net50_14w_8s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_14w_8s-6527dddc.pth',
|
||||
'res2net50_26w_6s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_6s-19041792.pth',
|
||||
'res2net50_26w_8s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_26w_8s-2c7c9f12.pth',
|
||||
'res2net101_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_26w_4s-02a759a1.pth',
|
||||
'res2net50_14w_8s': 'https://cloud.tsinghua.edu.cn/f/2543e4b5646d40a1afa9/?dl=1&fname=/res2net50_14w_8s.pkl',
|
||||
'res2net50_26w_4s': 'https://cloud.tsinghua.edu.cn/f/927fead9c9884f769d88/?dl=1&fname=/res2net50_26w_4s.pkl',
|
||||
'res2net50_26w_6s': 'https://cloud.tsinghua.edu.cn/f/067875edf6ca488ba83e/?dl=1&fname=/res2net50_26w_6s.pkl',
|
||||
'res2net50_26w_8s': 'https://cloud.tsinghua.edu.cn/f/ce1230155a2c4352bf17/?dl=1&fname=/res2net50_26w_8s.pkl',
|
||||
'res2net50_48w_2s': 'https://cloud.tsinghua.edu.cn/f/b8a4df2b2cb64500b869/?dl=1&fname=/res2net50_48w_2s.pkl',
|
||||
'res2net101_26w_4s': 'https://cloud.tsinghua.edu.cn/f/b85283bf572649d288bb/?dl=1&fname=/res2net101_26w_4s.pkl',
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -123,7 +123,7 @@ Example::
|
|||
shape = []
|
||||
len_c = max(len_a, len_b)
|
||||
(n, m), (m_, k) = a.shape[-2:], b.shape[-2:]
|
||||
assert m == m_, f"dimension not match, a.shape:{a.shape}, b.shape:{a.shape}"
|
||||
assert m == m_, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}"
|
||||
# a: [..., n, m]
|
||||
# b: [..., m, k]
|
||||
# cc:[..., n, m, k]
|
||||
|
@ -141,7 +141,7 @@ Example::
|
|||
an = a.shape[ai] if ai>=0 else 1
|
||||
bn = b.shape[bi] if bi>=0 else 1
|
||||
if an!=1 and bn!=1:
|
||||
assert an == bn, f"dimension not match, a.shape:{a.shape}, b.shape:{a.shape}"
|
||||
assert an == bn, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}"
|
||||
cn = max(an, bn)
|
||||
shape.append(cn)
|
||||
shape.extend([n, m, k])
|
||||
|
@ -341,16 +341,20 @@ def softmax(x, dim = None):
|
|||
x = (x-x.max(dim, keepdims=True)).exp()
|
||||
ret = x / x.sum(dim, keepdims=True)
|
||||
return ret
|
||||
jt.Var.softmax = softmax
|
||||
|
||||
def log_softmax(x,dim=None):
|
||||
x = softmax(x,dim=dim)
|
||||
return jt.log(x)
|
||||
jt.Var.log_softmax = log_softmax
|
||||
|
||||
def log_sigmoid(x):
|
||||
return jt.log(jt.sigmoid(x))
|
||||
jt.Var.log_sigmoid = log_sigmoid
|
||||
|
||||
def logsumexp(x, dim, keepdim=False):
|
||||
return x.exp().sum(dim, keepdim).log()
|
||||
jt.Var.logsumexp = logsumexp
|
||||
|
||||
class Identity(Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@ -1612,7 +1616,8 @@ class Sequential(Module):
|
|||
return
|
||||
parents.append(self)
|
||||
for k,v in self.layers.items():
|
||||
v.dfs(parents, k, callback, callback_leave)
|
||||
if isinstance(v, Module):
|
||||
v.dfs(parents, k, callback, callback_leave)
|
||||
parents.pop()
|
||||
if callback_leave:
|
||||
callback_leave(parents, k, self, n_children)
|
||||
|
|
|
@ -30,8 +30,10 @@ int current_seed;
|
|||
EXTERN_LIB list<VarPtr> fetcher;
|
||||
EXTERN_LIB list<VarPtr> fetcher_to_free;
|
||||
EXTERN_LIB vector<void(*)()> cleanup_callback;
|
||||
EXTERN_LIB bool exited;
|
||||
|
||||
void cleanup() {
|
||||
exited = true;
|
||||
fetcher_to_free.clear();
|
||||
fetcher.clear();
|
||||
for (auto cb : cleanup_callback)
|
||||
|
|
|
@ -17,7 +17,9 @@ void* AlignedAllocator::alloc(size_t size, size_t& allocation) {
|
|||
#ifndef _WIN32
|
||||
#ifdef __APPLE__
|
||||
size += 32-size%32;
|
||||
#endif
|
||||
// low version of mac don't have aligned_alloc
|
||||
return new char[size];
|
||||
#else
|
||||
return aligned_alloc(alignment, size);
|
||||
#else
|
||||
return _aligned_malloc(size, alignment);
|
||||
|
@ -28,8 +30,12 @@ void AlignedAllocator::free(void* mem_ptr, size_t size, const size_t& allocation
|
|||
#ifdef _WIN32
|
||||
_aligned_free(mem_ptr);
|
||||
#else
|
||||
#ifdef __APPLE__
|
||||
delete[] (char*)mem_ptr;
|
||||
#else
|
||||
::free(mem_ptr);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -241,7 +241,7 @@ void segfault_sigaction(int signal, siginfo_t *si, void *arg) {
|
|||
return;
|
||||
}
|
||||
if (signal == SIGCHLD) {
|
||||
if (si->si_code != CLD_EXITED && si->si_status != SIGTERM) {
|
||||
if (si->si_code != CLD_EXITED && si->si_status != SIGTERM && _pid == getpid()) {
|
||||
LOGe << "Caught SIGCHLD"
|
||||
<< "si_errno:" << si->si_errno
|
||||
<< "si_code:" << si->si_code
|
||||
|
@ -259,6 +259,7 @@ void segfault_sigaction(int signal, siginfo_t *si, void *arg) {
|
|||
exited = true;
|
||||
do_exit();
|
||||
}
|
||||
if (exited) do_exit();
|
||||
std::cerr << "Caught segfault at address " << si->si_addr << ", "
|
||||
<< "thread_name: '" << thread_name << "', flush log..." << std::endl;
|
||||
std::cerr.flush();
|
||||
|
@ -556,9 +557,11 @@ void system_with_check(const char* cmd, const char* cwd) {
|
|||
std::thread log_thread(log_main);
|
||||
#endif
|
||||
|
||||
int log_exit = 0;
|
||||
|
||||
void log_exiting() {
|
||||
if (exited) return;
|
||||
exited = true;
|
||||
if (log_exit) return;
|
||||
log_exit = true;
|
||||
for (auto cb : cleanup_callback)
|
||||
cb();
|
||||
cleanup_callback.clear();
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
def expect_error(func):
|
||||
try:
|
||||
|
@ -86,6 +87,17 @@ class TestCore(unittest.TestCase):
|
|||
c = np.matmul(a, b)
|
||||
jtc = jt.matmul(jt.array(a), jt.array(b)).data
|
||||
assert np.all(jtc == c)
|
||||
|
||||
def test_save_load_sub_module(self):
|
||||
class Net(jt.Module):
|
||||
def __init__(self):
|
||||
self.conv1 = jt.nn.Conv(3,3,3)
|
||||
net = Net()
|
||||
assert list(net.named_parameters().keys()) == ['conv1.weight', 'conv1.bias']
|
||||
assert list(net.conv1.named_parameters().keys()) == ['weight', 'bias']
|
||||
pkl_name = os.path.join(jt.flags.cache_path, "sub.pkl")
|
||||
net.conv1.save(pkl_name)
|
||||
net.conv1.load(pkl_name)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -0,0 +1,103 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guoye Yang <498731903@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
from jittor import compile_extern
|
||||
if jt.has_cuda:
|
||||
from jittor.compile_extern import cublas_ops, cudnn_ops, cub_ops
|
||||
else:
|
||||
cublas_ops = cudnn_ops = cub_ops = None
|
||||
|
||||
|
||||
def test_forward(shape, dim=None):
|
||||
x = jt.random(shape)
|
||||
y = jt.numpy_cumsum(x)
|
||||
y_ = jt.cub_cumsum(x)
|
||||
assert(np.allclose(y.data, y_.data))
|
||||
|
||||
def test_backward(shape, dim=None):
|
||||
x = jt.random(shape)
|
||||
z = jt.random(shape)
|
||||
|
||||
y = jt.numpy_cumsum(x)
|
||||
loss = (y * z).sum()
|
||||
grad = jt.grad(loss, x)
|
||||
|
||||
y_ = jt.cub_cumsum(x)
|
||||
loss_ = (y_ * z).sum()
|
||||
grad_ = jt.grad(loss_, x)
|
||||
assert(np.allclose(grad.data, grad_.data))
|
||||
|
||||
class TestCubCumsumOp(unittest.TestCase):
|
||||
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_1d(self):
|
||||
test_forward([20])
|
||||
test_forward([3007])
|
||||
test_forward([3007], 0)
|
||||
test_forward([3007], -1)
|
||||
|
||||
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_1d_backward(self):
|
||||
test_backward([20])
|
||||
test_backward([3007])
|
||||
test_backward([3007], 0)
|
||||
test_backward([3007], -1)
|
||||
|
||||
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_2d(self):
|
||||
test_forward([5,5])
|
||||
test_forward([2000, 3007])
|
||||
test_forward([2000, 3007], 1)
|
||||
test_forward([2000, 3007], -1)
|
||||
|
||||
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_2d_backward(self):
|
||||
test_backward([5,5])
|
||||
test_backward([2000, 3007])
|
||||
test_backward([2000, 3007], 1)
|
||||
test_backward([2000, 3007], -1)
|
||||
|
||||
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_nd(self):
|
||||
test_forward([5,6,7,8], 0)
|
||||
test_forward([5,6,7,8], 1)
|
||||
test_forward([5,6,7,8], 2)
|
||||
test_forward([5,6,7,8], 3)
|
||||
test_forward([5,6,7,8], -1)
|
||||
test_forward([16,14,14,2048], 0)
|
||||
test_forward([16,14,14,2048], 1)
|
||||
test_forward([16,14,14,2048], 2)
|
||||
test_forward([16,14,14,2048], 3)
|
||||
test_forward([16,14,14,2048], -1)
|
||||
test_forward([16,14,14,2047], 3)
|
||||
|
||||
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_nd_backward(self):
|
||||
test_backward([5,6,7,8], 0)
|
||||
test_backward([5,6,7,8], 1)
|
||||
test_backward([5,6,7,8], 2)
|
||||
test_backward([5,6,7,8], 3)
|
||||
test_backward([5,6,7,8], -1)
|
||||
test_backward([16,14,14,2048], 0)
|
||||
test_backward([16,14,14,2048], 1)
|
||||
test_backward([16,14,14,2048], 2)
|
||||
test_backward([16,14,14,2048], 3)
|
||||
test_backward([16,14,14,2048], -1)
|
||||
test_backward([16,14,14,2047], 3)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@@ -267,6 +267,45 @@ if __name__ == "__main__":
        assert "quick exit" in s


    @unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
    def test_dataset_shuffle_mpi(self):
        src = """
import jittor as jt
from jittor.dataset import Dataset
import numpy as np

class YourDataset(Dataset):
    def __init__(self):
        super().__init__()
        self.set_attrs(total_len=160, shuffle=True)

    def __getitem__(self, k):
        return k

dataset = YourDataset()
dataset.set_attrs(num_workers=2)

for d in dataset:
    for a in d:
        print("CHECK: ", a.item())
"""
        fname = os.path.join(jt.flags.cache_path, "test_dataset_shuffle_mpi.py")
        with open(fname, 'w') as f:
            f.write(src)
        import subprocess as sp
        import sys
        cmd = sys.executable + " " + fname
        mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
        cmd = mpirun_path + " -np 2 " + cmd
        print(cmd)
        r = sp.run(cmd, shell=True, stdout=sp.PIPE, stderr=sp.PIPE)
        s = r.stdout.decode()
        # print(s)
        st = set([ l for l in s.splitlines() if l.startswith("CHECK:") ])
        assert r.returncode == 0
        # print(st)
        assert len(st) == 160, len(st)

    def test_children_died2(self):
        src = """
import jittor as jt
@@ -0,0 +1,31 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
#     Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
from jittor.dataset.mnist import EMNIST, MNIST
import jittor.transform as transform

@unittest.skipIf(True, f"skip emnist test")
class TestEMNIST(unittest.TestCase):
    def test_emnist(self):
        import pylab as pl
        # emnist_dataset = EMNIST()
        emnist_dataset = EMNIST()
        for imgs, labels in emnist_dataset:
            print(imgs.shape, labels.shape)
            print(labels.max(), labels.min())
            # imgs = imgs.transpose(0,1,3,2).transpose(1,2,0,3)[0].reshape(28, -1)
            imgs = imgs.transpose(1,2,0,3)[0].reshape(28, -1)
            print(labels)
            pl.imshow(imgs), pl.show()
            break


if __name__ == "__main__":
    unittest.main()
@@ -47,7 +47,17 @@ class TestLoadPth(unittest.TestCase):
        jt_out = jt_model(jt_img)
        torch_out = torch_model(torch_img)
        print(np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())))
        assert np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())) < 1e-4
        assert np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())) < 1e-3

        pth_name = os.path.join(jt.flags.cache_path, "x.pth")
        torch.save(torch_model.state_dict, pth_name)
        jt_model.load(pth_name)

        # output
        jt_out = jt_model(jt_img)
        # torch_out = torch_model(torch_img)
        print(np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())))
        assert np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())) < 1e-3

if __name__ == "__main__":
    unittest.main()

@@ -1,3 +1,13 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
#     Zheng-Ning Liu <lzhengning@gmail.com>
#     Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

import unittest
import numpy as np

@@ -165,6 +165,8 @@ jt.mkl_ops.mkl_conv(x, w, 1, 1, 2, 2).sync()
            n += 1
        assert n == 2
        assert list(x.keys()) == [0,1]
        p = x.parameters()
        assert len(p)==0

    # def test_res2net(self):
    #     import jittor.models

@@ -159,6 +159,86 @@ class TestPad(unittest.TestCase):
            out.detach().numpy(), output.data,
            atol=1e-4)

    def test_ctc_loss(self):
        def check(T,C,N,S,S_min):
            jt.set_global_seed(0)

            # Initialize random batch of input vectors, for *size = (T,N,C)
            input = jt.randn(T, N, C).log_softmax(2)
            # input = -jt.ones((T, N, C))
            # input[0,0,1] += 0.01

            # Initialize random batch of targets (0 = blank, 1:C = classes)
            target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int)
            _input_jt = input

            input_lengths = jt.full((N,), T, dtype=jt.int)
            target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int)
            # ctc_loss = nn.CTCLoss()
            loss = jt.ctc_loss(input, target, input_lengths, target_lengths, reduction='none')
            _loss_jt = loss

            loss_jt = loss.numpy()

            input = torch.Tensor(input.numpy()).detach().requires_grad_()
            input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
            target_lengths = torch.LongTensor(target_lengths.numpy())
            input_lengths = torch.LongTensor(input_lengths.numpy())
            target = torch.LongTensor(target.numpy())
            loss = tnn.CTCLoss(reduction='none')(input, target, input_lengths, target_lengths)
            np.testing.assert_allclose(loss.detach().numpy(), loss_jt, rtol=1e-5, atol=1e-5)

            dinput_jt = jt.grad(_loss_jt, _input_jt)
            dinput_jt.sync()

            loss.sum().backward()
            # print(input.grad)
            # print(dinput_jt)
            # print(loss)

        def check_gpu_with_cpu(T,C,N,S,S_min):
            jt.set_global_seed(1)

            # Initialize random batch of input vectors, for *size = (T,N,C)
            input = jt.randn(T, N, C).log_softmax(2)
            # input = -jt.ones((T, N, C))
            # input[0,0,1] += 0.01

            # Initialize random batch of targets (0 = blank, 1:C = classes)
            target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int)
            _input_jt = input

            input_lengths = jt.full((N,), T, dtype=jt.int)
            target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int)
            # ctc_loss = nn.CTCLoss()
            loss = jt.ctc_loss(input, target, input_lengths, target_lengths, reduction='none')
            _loss_jt = loss

            loss_jt = loss.numpy()

            dinput_jt = jt.grad(_loss_jt, _input_jt)
            dinput_jt.sync()

            with jt.flag_scope(use_cuda=1):
                input = input.copy()
                target = target.copy()
                input_lengths = input_lengths.copy()
                target_lengths = target_lengths.copy()
                loss = jt.ctc_loss(input, target, input_lengths, target_lengths, reduction='none')
                grad = jt.grad(loss, input)
                np.testing.assert_allclose(_loss_jt.numpy(), loss.numpy(), atol=1e-5, rtol=1e-5)
                np.testing.assert_allclose(dinput_jt.numpy(), grad.numpy(), atol=1e-5, rtol=1e-5)


        check(2,2,1,1,1)
        check(50,20,16,30,10)

        if jt.has_cuda:
            with jt.flag_scope(use_cuda=1):
                check(2,2,1,1,1)
                check(50,20,16,30,10)
                check_gpu_with_cpu(50,20,16,30,10)

class TestOther(unittest.TestCase):
    def test_save(self):
        pp = [1,2,jt.array([1,2,3]), {"a":[1,2,3], "b":jt.array([1,2,3])}]

@@ -0,0 +1,61 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
#     Zheng-Ning Liu <lzhengning@gmail.com>
#     Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************


import unittest
import jittor as jt
import numpy as np


class TestRepeatOp(unittest.TestCase):
    def test_repeat(self):
        np_a = np.arange(5)
        jt_a = jt.array(np_a)

        np_b = np.tile(np_a, (2, 3))
        jt_b = jt.repeat(jt_a, (2, 3))

        assert np.allclose(np_b, jt_b.data)

        np_b = np.tile(np_a, (2, 3, 1))
        jt_b = jt.repeat(jt_a, (2, 3, 1))

        assert np.allclose(np_b, jt_b.data)

        np_a = np.arange(24).reshape(2, 3, 4)
        jt_a = jt.array(np_a)

        np_b = np.tile(np_a, (2, 3))
        jt_b = jt.repeat(jt_a, (2, 3))

        assert np.allclose(np_b, jt_b.data)


    def test_highdim(self):
        np_a = np.arange(64).reshape(2, 2, 2, 2, 2, 2)
        jt_a = jt.array(np_a)

        np_b = np.tile(np_a, (2, 3))
        jt_b = jt.repeat(jt_a, (2, 3))

        assert np.allclose(np_b, jt_b.data)

        np_b = np.tile(np_a, (2, 1, 1, 3))
        jt_b = jt.repeat(jt_a, (2, 1, 1, 3))

        assert np.allclose(np_b, jt_b.data)

        np_b = np.tile(np_a, (2, 1, 1, 1, 3, 1))
        jt_b = jt.repeat(jt_a, (2, 1, 1, 1, 3, 1))

        assert np.allclose(np_b, jt_b.data)

if __name__ == "__main__":
    unittest.main()
@@ -954,6 +954,22 @@ class Tester(unittest.TestCase):
            transform.ToTensor(),
        ])(img)

    def test_not_pil_image(self):
        img = jt.random((30,40,3))
        result = transform.Compose([
            transform.RandomAffine(20),
            transform.ToTensor(),
        ])(img)

        img = jt.random((30,40,3))
        result = transform.Compose([
            transform.ToPILImage(),
            transform.Gray(),
            transform.Resize(20),
            transform.ToTensor(),
        ])(img)



if __name__ == '__main__':
@@ -152,7 +152,9 @@ class Crop:
        self.left = left
        self.height = height
        self.width = width
    def __call__(self, img):
    def __call__(self, img:Image.Image):
        if not isinstance(img, Image.Image):
            img = to_pil_image(img)
        return crop(img, self.top, self.left, self.height, self.width)


@@ -181,6 +183,8 @@ class RandomCropAndResize:
        self.interpolation = interpolation

    def __call__(self, img:Image.Image):
        if not isinstance(img, Image.Image):
            img = to_pil_image(img)
        width, height = img.size
        scale = self.scale
        ratio = self.ratio
@@ -363,6 +367,8 @@ class RandomHorizontalFlip:
        self.p = p

    def __call__(self, img:Image.Image):
        if not isinstance(img, Image.Image):
            img = to_pil_image(img)
        if random.random() < self.p:
            return img.transpose(Image.FLIP_LEFT_RIGHT)
        return img
@@ -384,6 +390,8 @@ class CenterCrop:
        self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values")

    def __call__(self, img:Image.Image):
        if not isinstance(img, Image.Image):
            img = to_pil_image(img)
        width, height = img.size
        return crop(img, (height - self.size[0]) / 2, (width - self.size[1]) / 2, self.size[0], self.size[1])


@@ -682,6 +690,8 @@ class Resize:
        self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values")
        self.mode = mode
    def __call__(self, img:Image.Image):
        if not isinstance(img, Image.Image):
            img = to_pil_image(img)
        return resize(img, self.size, self.mode)

class Gray:
@@ -697,6 +707,8 @@ class Gray:
        self.num_output_channels = num_output_channels

    def __call__(self, img:Image.Image):
        if not isinstance(img, Image.Image):
            img = to_pil_image(img)
        img = np.float32(img.convert('L')) / np.float32(255.0)
        if self.num_output_channels == 1:
            return img[np.newaxis, :]
@@ -720,7 +732,9 @@ class RandomGray:
    def __init__(self, p=0.1):
        self.p = p

    def __call__(self, img: Image.Image):
    def __call__(self, img:Image.Image):
        if not isinstance(img, Image.Image):
            img = to_pil_image(img)
        num_output_channels = _get_image_num_channels(img)
        if random.random() < self.p:
            return gray(img, num_output_channels=num_output_channels)
@@ -742,6 +756,8 @@ class RandomCrop:
    def __init__(self, size):
        self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values")
    def __call__(self, img:Image.Image):
        if not isinstance(img, Image.Image):
            img = to_pil_image(img)
        width, height = img.size
        assert self.size[0] <= height and self.size[1] <= width, f"crop size exceeds the input image in RandomCrop, {(self.size, height, width)}"
        top = np.random.randint(0,height-self.size[0]+1)
@@ -835,7 +851,9 @@ class RandomVerticalFlip:
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img: Image.Image):
    def __call__(self, img:Image.Image):
        if not isinstance(img, Image.Image):
            img = to_pil_image(img)
        if random.random() < self.p:
            return vflip(img)
        return img
@@ -918,13 +936,15 @@ class ColorJitter:

        return transform

    def __call__(self, img):
    def __call__(self, img:Image.Image):
        """
        Args::
            [in] img (PIL Image): Input image.
        Returns::
            [out] PIL Image: Color jittered image.
        """
        if not isinstance(img, Image.Image):
            img = to_pil_image(img)
        transform = self._get_transform(self.brightness, self.contrast, self.saturation, self.hue)

        return transform(img)
@@ -1002,7 +1022,7 @@ class RandomPerspective(object):
        self.interpolation = interpolation
        self.distortion_scale = distortion_scale

    def __call__(self, img):
    def __call__(self, img:Image.Image):
        """
        Args:
            img (PIL Image): Image to be Perspectively transformed.
@@ -1011,7 +1031,7 @@ class RandomPerspective(object):
            PIL Image: Random perspectivley transformed image.
        """
        if not isinstance(img, Image.Image):
            raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
            img = to_pil_image(img)

        if random.random() < self.p:
            width, height = img.size
@@ -1119,7 +1139,7 @@ class RandomResizedCrop(object):
            j = (width - w) // 2
        return i, j, h, w

    def __call__(self, img):
    def __call__(self, img:Image.Image):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.
@@ -1127,6 +1147,8 @@ class RandomResizedCrop(object):
        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        if not isinstance(img, Image.Image):
            img = to_pil_image(img)
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        return F_pil.resized_crop(img, i, j, h, w, self.size, self.interpolation)


@@ -1174,7 +1196,9 @@ class FiveCrop(object):
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
        self.size = size

    def __call__(self, img):
    def __call__(self, img:Image.Image):
        if not isinstance(img, Image.Image):
            img = to_pil_image(img)
        return F_pil.five_crop(img, self.size)

    def __repr__(self):
@@ -1217,7 +1241,9 @@ class TenCrop(object):
        self.size = size
        self.vertical_flip = vertical_flip

    def __call__(self, img):
    def __call__(self, img:Image.Image):
        if not isinstance(img, Image.Image):
            img = to_pil_image(img)
        return F_pil.ten_crop(img, self.size, self.vertical_flip)

    def __repr__(self):
@@ -1275,7 +1301,7 @@ class RandomRotation(object):

        return angle

    def __call__(self, img):
    def __call__(self, img:Image.Image):
        """
        Args:
            img (PIL Image): Image to be rotated.
@@ -1283,7 +1309,8 @@ class RandomRotation(object):
        Returns:
            PIL Image: Rotated image.
        """

        if not isinstance(img, Image.Image):
            img = to_pil_image(img)
        angle = self.get_params(self.degrees)

        return F_pil.rotate(img, angle, self.resample, self.expand, self.center, self.fill)
@@ -1405,13 +1432,15 @@ class RandomAffine(object):

        return angle, translations, scale, shear

    def __call__(self, img):
    def __call__(self, img:Image.Image):
        """
            img (PIL Image): Image to be transformed.

        Returns:
            PIL Image: Affine transformed image.
        """
        if not isinstance(img, Image.Image):
            img = to_pil_image(img)
        ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
        return F_pil.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor)
