This commit is contained in:
cxjyxx_me 2020-12-30 11:26:49 +08:00
commit 1be344526c
463 changed files with 2662 additions and 1157 deletions

View File

@ -119,7 +119,9 @@ Jittor 一共提供三种方式安装: pip安装, 一键脚本安装 和 手动
```bash
sudo apt install python3.7-dev libomp-dev
sudo python3.7 -m pip install git+https://github.com/Jittor/jittor.git
python3.7 -m pip install jittor
# or install from github(latest version)
# python3.7 -m pip install git+https://github.com/Jittor/jittor.git
python3.7 -m jittor.test.test_example
```
@ -217,13 +219,13 @@ jt.flags.use_cuda = 1
```
### 可选步骤五:进行完整测试
### 可选步骤五:测试训练Resnet18
要检查Jittor的完整性您可以运行完整的测试。
要检查Jittor的完整性您可以运行Resnet18训练测试。
```bash
python3.7 -m jittor.test -v
python3.7 -m jittor.test.test_resnet
```
如果这些测试失败请为我们报告错误我们十分欢迎您为Jittor做出贡献^ _ ^
@ -360,10 +362,29 @@ Jittor还很年轻。 它可能存在错误和问题。 请在我们的错误跟
QQ 群761222083
## 团队
Jittor目前由来自[清华大学计算机图形学组](https://cg.cs.tsinghua.edu.cn/)的梁盾,杨国烨,杨国炜,周文洋和国孟昊等博士生维护。 如果您也对Jittor感兴趣并希望对其进行改进请加入我们
Jittor目前由[清华大学计算机图形学组](https://cg.cs.tsinghua.edu.cn/)维护。 如果您也对Jittor感兴趣并希望对其进行改进请加入我们
## 引用
```
@article{hu2020jittor,
title={Jittor: a novel deep learning framework with meta-operators and unified graph execution},
author={Hu, Shi-Min and Liang, Dun and Yang, Guo-Ye and Yang, Guo-Wei and Zhou, Wen-Yang},
journal={Information Sciences},
volume={63},
number={222103},
pages={1--222103},
year={2020}
}
```
## 版权声明

View File

@ -116,7 +116,9 @@ Jittor offers three ways to install: pip, script or manual.
```bash
sudo apt install python3.7-dev libomp-dev
sudo python3.7 -m pip install git+https://github.com/Jittor/jittor.git
python3.7 -m pip install jittor
# or install from github(latest version)
# python3.7 -m pip install git+https://github.com/Jittor/jittor.git
python3.7 -m jittor.test.test_example
```
@ -210,14 +212,14 @@ import jittor as jt
jt.flags.use_cuda = 1
```
### Optional Step 5: Run full tests
### Optional Step 5: Test Resnet18 training
To check the integrity of Jittor, you can run full tests.
To check the integrity of Jittor, you can run Resnet18 training test.
```bash
python3.7 -m jittor.test -v
python3.7 -m jittor.test.test_resnet
```
if those tests are failed, please report bugs for us, and feel free to contribute ^_^
@ -353,12 +355,32 @@ Email: jittor@qq.com
File an issue: https://github.com/Jittor/jittor/issues
QQ Group: 761222083
<img src="https://cg.cs.tsinghua.edu.cn/jittor/images/news/2020-12-8-21-19-1_2_2/fig4.png" width="200"/>
## The Team
Jittor is currently maintained by Dun Liang, Guo-Ye Yang, Guo-Wei Yang, Wen-Yang Zhou and Meng-Hao Guo etc. from the [Tsinghua CSCG Group](https://cg.cs.tsinghua.edu.cn/). If you are also interested in Jittor and want to improve it, Please join us!
Jittor is currently maintained by the [Tsinghua CSCG Group](https://cg.cs.tsinghua.edu.cn/). If you are also interested in Jittor and want to improve it, Please join us!
## Citation
```
@article{hu2020jittor,
title={Jittor: a novel deep learning framework with meta-operators and unified graph execution},
author={Hu, Shi-Min and Liang, Dun and Yang, Guo-Ye and Yang, Guo-Wei and Zhou, Wen-Yang},
journal={Information Sciences},
volume={63},
number={222103},
pages={1--222103},
year={2020}
}
```
## License

View File

@ -151,7 +151,9 @@ Jittor 一共提供三种方式安装: pip安装, 一键脚本安装 和 手动
```bash
sudo apt install python3.7-dev libomp-dev
sudo python3.7 -m pip install git+https://github.com/Jittor/jittor.git
python3.7 -m pip install jittor
# or install from github(latest version)
# python3.7 -m pip install git+https://github.com/Jittor/jittor.git
python3.7 -m jittor.test.test_example
```
@ -453,13 +455,35 @@ Email: jittor@qq.com
File an issue: https://github.com/Jittor/jittor/issues
QQ Group: 761222083
QQ 群761222083
<img src="https://cg.cs.tsinghua.edu.cn/jittor/images/news/2020-12-8-21-19-1_2_2/fig4.png" width="200"/>
## The Team
## 团队
Jittor is currently maintained by Dun Liang, Guo-Ye Yang, Guo-Wei Yang, Wen-Yang Zhou and Meng-Hao Guo etc. from the [Tsinghua CSCG Group](https://cg.cs.tsinghua.edu.cn/). If you are also interested in Jittor and want to improve it, Please join us!
Jittor is currently maintained by the [Tsinghua CSCG Group](https://cg.cs.tsinghua.edu.cn/). If you are also interested in Jittor and want to improve it, Please join us!
Jittor目前由来自[清华大学计算机图形学组](https://cg.cs.tsinghua.edu.cn/)的梁盾,杨国烨,杨国炜,周文洋和国孟昊等博士生维护。 如果您也对Jittor感兴趣并希望对其进行改进请加入我们
Jittor目前由[清华大学计算机图形学组](https://cg.cs.tsinghua.edu.cn/)维护。 如果您也对Jittor感兴趣并希望对其进行改进请加入我们
## Citation
## 引用
```
@article{hu2020jittor,
title={Jittor: a novel deep learning framework with meta-operators and unified graph execution},
author={Hu, Shi-Min and Liang, Dun and Yang, Guo-Ye and Yang, Guo-Wei and Zhou, Wen-Yang},
journal={Information Sciences},
volume={63},
number={222103},
pages={1--222103},
year={2020}
}
```
## License

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,5 +1,6 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,5 +1,6 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Xiangli Li <1905692338@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Xiangli Li <1905692338@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Meng-Hao Guo <guomenghao1997@gmail.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Meng-Hao Guo <guomenghao1997@gmail.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guowei Yang <471184555@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,5 +1,6 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,5 +1,6 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,5 +1,6 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Dun Liang <randonlang@gmail.com>
// Guowei Yang <471184555@qq.com>
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Dun Liang <randonlang@gmail.com>
// Guowei Yang <471184555@qq.com>
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Dun Liang <randonlang@gmail.com>
// Guowei Yang <471184555@qq.com>
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Dun Liang <randonlang@gmail.com>
// Guowei Yang <471184555@qq.com>
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,6 +1,7 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,5 +1,6 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,5 +1,6 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,5 +1,6 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,5 +1,6 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,5 +1,6 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>.
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>.
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor.
// Authors:
// All Rights Reserved.
// Maintainers:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor.
// Authors:
// All Rights Reserved.
// Maintainers:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor.
// Authors:
// All Rights Reserved.
// Maintainers:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -15,7 +15,6 @@
#include <cuda_runtime.h>
#include "helper_cuda.h"
#ifdef _CUFFT_H_
// cuFFT API errors
const char *_cudaGetErrorEnum(cufftResult error) {

View File

@ -1,9 +1,10 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guowei Yang <471184555@qq.com>
// Guoye Yang <498731903@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guowei Yang <471184555@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,9 +1,10 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guowei Yang <471184555@qq.com>
// Guoye Yang <498731903@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guowei Yang <471184555@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guowei Yang <471184555@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guowei Yang <471184555@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guowei Yang <471184555@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,5 +1,6 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,5 +1,6 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor.
// Authors:
// All Rights Reserved.
// Maintainers:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,8 +1,9 @@
// ***************************************************************
// Copyright (c) 2020 Jittor.
// Authors:
// All Rights Reserved.
// Maintainers:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,13 +1,14 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Dun Liang <randonlang@gmail.com>.
# Meng-Hao Guo <guomenghao1997@gmail.com>
#
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.0'
__version__ = '1.2.2.12'
from . import lock
with lock.lock_scope():
ori_int = int
@ -33,9 +34,38 @@ from collections import OrderedDict
from collections.abc import Sequence, Mapping
import types
import pickle
import sys
import hashlib
import sys, os
import traceback
def safepickle(obj, path):
s = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
checksum = hashlib.sha1(s).digest()
s += bytes(checksum)
s += b"HCAJSLHD"
with open(path, 'wb') as f:
f.write(s)
def safeunpickle(path):
if path.startswith("jittorhub://"):
path = path.replace("jittorhub://", "https://cg.cs.tsinghua.edu.cn/jittor/assets/build/checkpoints/")
if path.startswith("https:") or path.startswith("http:"):
base = path.split("/")[-1]
fname = os.path.join(compiler.ck_path, base)
from jittor.utils.misc import download_url_to_local
download_url_to_local(path, base, compiler.ck_path, None)
path = fname
with open(path, "rb") as f:
s = f.read()
if not s.endswith(b"HCAJSLHD"):
return pickle.loads(s)
checksum = s[-28:-8]
s = s[:-28]
if hashlib.sha1(s).digest() != checksum:
raise ValueError("Pickle checksum does not match! path: "+path)
return pickle.loads(s)
class _call_no_record_scope:
def __enter__(self): pass
def __exit__(self, *exc): pass
@ -307,9 +337,9 @@ def flatten(input, start_dim=0, end_dim=-1):
return input.reshape(out_shape)
Var.flatten = flatten
def detach_inplace(x):
return x.swap(x.stop_grad().clone())
Var.start_grad = Var.detach_inplace = detach_inplace
def start_grad(x):
return x._update(x)
Var.detach_inplace = Var.start_grad = start_grad
def detach(x):
return x.detach()
@ -436,10 +466,30 @@ def display_memory_info():
core.display_memory_info(fileline)
def load(path):
pkl_file = open(path, 'rb')
model_dict = pickle.load(pkl_file)
if path.endswith(".pth"):
try:
dirty_fix_pytorch_runtime_error()
import torch
except:
raise RuntimeError("pytorch need to be installed when load pth format.")
model_dict = torch.load(path, map_location=torch.device('cpu'))
else:
model_dict = safeunpickle(path)
return model_dict
def save(params_dict, path):
def dfs(x):
if isinstance(x, list):
for i in range(len(x)):
x[i] = dfs(x[i])
elif isinstance(x, dict):
for k in x:
x[k] = dfs(x[k])
elif isinstance(x, Var):
return x.numpy()
return x
safepickle(dfs(params_dict), path)
def _uniq(x):
a = set()
b = []
@ -503,6 +553,7 @@ class Module:
def callback(parents, k, v, n):
stack.append(str(k))
for k2, p in v.__dict__.items():
if k2.startswith("_"): continue
if isinstance(p, Var):
ps.append(p)
p.name(".".join(stack[1:]+[str(k2)]))
@ -560,6 +611,21 @@ class Module:
return ret
self.__class__.__call__ = new_call
def register_pre_forward_hook(self, func):
cls = self.__class__
self.__fhook2__ = func
if hasattr(cls, "__hooked2__"):
return
cls.__hooked2__ = True
origin_call = cls.__call__
def new_call(self, *args, **kw):
if hasattr(self, "__fhook2__"):
if len(kw):
self.__fhook2__(self, args, kw)
else:
self.__fhook2__(self, args)
return origin_call(self, *args, **kw)
self.__class__.__call__ = new_call
def children(self):
cd = []
@ -632,20 +698,10 @@ class Module:
params_dict = {}
for p in params:
params_dict[p.name()] = p.data
with open(path, 'wb') as f:
pickle.dump(params_dict, f, pickle.HIGHEST_PROTOCOL)
safepickle(params_dict, path)
def load(self, path):
if path.endswith(".pth"):
try:
dirty_fix_pytorch_runtime_error()
import torch
except:
raise RuntimeError("pytorch need to be installed when load pth format.")
self.load_parameters(torch.load(path, map_location=torch.device('cpu')))
return
with open(path, 'rb') as f:
self.load_parameters(pickle.load(f))
self.load_parameters(load(path))
def eval(self):
def callback(parents, k, v, n):
@ -790,6 +846,11 @@ can also be None)::
def dfs(self, parents, k, callback, callback_leave=None):
pass
@classmethod
def apply(cls, *args, **kw):
func = cls()
return func(*args, **kw)
def make_module(func, exec_n_args=1):
class MakeModule(Module):
@ -889,7 +950,10 @@ def format(v, spec):
return v.item().__format__(spec)
Var.__format__ = format
def get_len(var):
return var.shape[0]
Var.__len__ = get_len
int = int32
Var.int = Var.int32
float = float32
@ -906,3 +970,27 @@ from . import contrib
from . import numpy2cupy
from .contrib import concat
from .misc import *
from . import sparse
def randn(*size, dtype="float32", requires_grad=False):
if isinstance(size, tuple) and isinstance(size[0], tuple): size = size[0]
arr = jt.random(size, dtype, "normal")
if not requires_grad: return arr.stop_grad()
return arr
def rand(*size, dtype="float32", requires_grad=False):
if isinstance(size, tuple) and isinstance(size[0], tuple): size = size[0]
arr = jt.random(size, dtype)
if not requires_grad: return arr.stop_grad()
return arr
def normal(mean, std, size=None, dtype="float32"):
if size is None:
if isinstance(mean, Var) and isinstance(std, Var):
assert mean.shape == std.shape
size = mean.shape
else:
if isinstance(mean, Var): size = mean.shape
if isinstance(std, Var): size = std.shape
return jt.init.gauss(size, dtype, mean, std)

View File

@ -1,9 +1,10 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

View File

@ -1,5 +1,6 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -138,6 +139,10 @@ def setup_cuda_extern():
import traceback
line = traceback.format_exc()
LOG.w(f"CUDA found but {lib_name} is not loaded:\n{line}")
if lib_name == "cudnn":
LOG.w(f"Develop version of CUDNN not found, "
"please refer to CUDA offical tar file installation: "
"https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar")
def setup_cuda_lib(lib_name, link=True, extra_flags=""):
globals()[lib_name+"_ops"] = None
@ -183,7 +188,7 @@ def install_cutt(root_folder):
filename = "cutt-master.zip"
fullname = os.path.join(root_folder, filename)
dirname = os.path.join(root_folder, filename.replace(".zip",""))
true_md5 = "a6f4f7f75310a69b131e21f1ebec768a"
true_md5 = "af5bc35eea1832a42c0e0011659b7209"
if os.path.exists(fullname):
md5 = run_cmd('md5sum '+fullname).split()[0]
@ -205,7 +210,11 @@ def install_cutt(root_folder):
zf.close()
LOG.i("installing cutt...")
run_cmd(f"make", cwd=dirname)
arch_flag = ""
if len(flags.cuda_archs):
arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
run_cmd(f"make NVCC_GENCODE='{arch_flag}' nvcc_path='{nvcc_path}'", cwd=dirname)
return dirname
def setup_cutt():

View File

@ -1,5 +1,6 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -74,10 +75,17 @@ def compile(compiler, flags, inputs, output, combind_build=False):
for input, obj_file in zip(inputs, obj_files):
cc = compiler
nflags = oflags
if has_cuda and input.endswith(".cu"):
nflags = convert_nvcc_flags(oflags)
cc = nvcc_path
if input.endswith(".cu"):
if has_cuda:
nflags = convert_nvcc_flags(oflags)
cc = nvcc_path
else:
continue
cmd = f"{cc} {input} {nflags} -c {lto_flags} -o {obj_file}"
if "nan_checker" in input:
# nan checker needs to disable fast_math
cmd = cmd.replace("--use_fast_math", "")
cmd = cmd.replace("-Ofast", "-O2")
cmds.append(cmd)
jit_utils.run_cmds(cmds, cache_path, jittor_path, "Compiling "+base_output)
cmd = f"{compiler} {' '.join(obj_files)} {flags} {lto_flags} {link} -o {output}"
@ -894,6 +902,8 @@ make_cache_dir(cache_path)
make_cache_dir(os.path.join(cache_path, "jit"))
make_cache_dir(os.path.join(cache_path, "obj_files"))
make_cache_dir(os.path.join(cache_path, "gen"))
ck_path = os.path.join(cache_path, "checkpoints")
make_cache_dir(ck_path)
# build cache_compile
cc_flags += f" -I{jittor_path}/src "
@ -943,7 +953,8 @@ pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path)
# 3. op_utils
# 4. other
files2 = pyjt_gen_src
files4 = run_cmd('find -L src | grep "cc$"', jittor_path).splitlines()
grep_args = '"c[cu]$"' if has_cuda else '"cc$"'
files4 = run_cmd('find -L src | grep '+grep_args, jittor_path).splitlines()
at_beginning = [
"src/ops/op_utils.cc",
"src/event_queue.cc",

View File

@ -1,9 +1,10 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Guowei Yang <471184555@qq.com>
# Guoye Yang <498731903@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -12,6 +13,7 @@ import numpy as np
from jittor import pool
from collections.abc import Sequence
def argmax_pool(x, size, stride, padding=0):
return pool.pool(x, size, 'maximum', padding, stride)
@ -196,28 +198,37 @@ def setitem(x, slices, value):
mask = jt.broadcast(slices, x)
value = jt.broadcast(value, x)
return x.assign(mask.ternary(value, x))
if isinstance(slices, list):
slices = tuple(slices)
if isinstance(slices, Sequence):
ss = []
for s in slices:
if isinstance(s, jt.Var) and s.dtype == "bool":
ss.extend(s.where())
else:
ss.append(s)
slices = tuple(ss)
return x.assign(x.setitem(slices, value))
jt.Var.__getitem__ = jt.Var.slice_var = getitem
jt.Var.__setitem__ = setitem
def concat(arr, dim):
def concat(arr, dim=0):
'''Concat Operator can concat a list of jt Var at a specfic dimension.
* [in] x: input var list for concat
* [in] dim: concat which dim
* [out] out: concat result
* return: concat result
Example::
jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
# return [[1],[2],[2],[2]]
'''
# TODO: low performance when concat lots of vars
if not isinstance(arr, Sequence):
raise TypeError("concat arr needs to be a tuple or list")
if len(arr) == 0:
raise ValueError("need at least one array to concat")
total_dim = 0
if dim < 0: dim += len(arr[0].shape)
for a in arr:

View File

@ -1,8 +1,9 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Meng-Hao Guo <guomenghao1997@gmail.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -27,8 +28,9 @@ mpi = jt.mpi
img_open_hook = HookTimer(Image, "open")
class Worker:
def __init__(self, target, args, buffer_size):
def __init__(self, target, args, buffer_size, keep_numpy_array=False):
self.buffer = jt.RingBuffer(buffer_size)
self.buffer.keep_numpy_array(keep_numpy_array)
self.status = mp.Array('f', 5, lock=False)
self.p = mp.Process(target=target, args=args+(self.buffer,self.status))
@ -67,7 +69,8 @@ class Dataset(object):
drop_last = False,
num_workers = 0,
buffer_size = 512*1024*1024,
stop_grad = True):
stop_grad = True,
keep_numpy_array = False):
super().__init__()
self.total_len = None
self.batch_size = batch_size
@ -76,6 +79,7 @@ class Dataset(object):
self.num_workers = num_workers
self.buffer_size = buffer_size
self.stop_grad = stop_grad
self.keep_numpy_array = keep_numpy_array
def __getitem__(self, index):
raise NotImplementedError
@ -101,7 +105,7 @@ class Dataset(object):
Attrs:
* batch_size(int): batch size, default 16.
* totol_len(int): totol lenght.
* total_len(int): total lenght.
* shuffle(bool): shuffle at each epoch, default False.
* drop_last(bool): if true, the last batch of dataset might smaller than batch_size, default True.
* num_workers: number of workers for loading data
@ -111,12 +115,15 @@ class Dataset(object):
for k,v in kw.items():
assert hasattr(self, k), k
setattr(self, k, v)
self.reset()
return self
def to_jittor(self, batch):
'''
Change batch data to jittor array, such as np.ndarray, int, and float.
'''
if self.keep_numpy_array: return batch
if isinstance(batch, jt.Var): return batch
to_jt = lambda x: jt.array(x).stop_grad() \
if self.stop_grad else jt.array(x)
if isinstance(batch, np.ndarray):
@ -129,7 +136,7 @@ class Dataset(object):
isinstance(a, float):
new_batch.append(to_jt(a))
else:
new_batch.append(a)
new_batch.append(self.to_jittor(a))
return new_batch
def collate_batch(self, batch):
@ -299,11 +306,26 @@ Example::
self.num_idle_c = mp.Condition(self.gid.get_lock())
for i in range(self.num_workers):
w = Worker(target=self._worker_main, args=(i,),
buffer_size=self.buffer_size)
buffer_size=self.buffer_size,
keep_numpy_array=self.keep_numpy_array)
workers.append(w)
self.workers = workers
self.index_list_numpy = np.ndarray(dtype='int32', shape=self.real_len, buffer=self.index_list)
def reset(self):
if not hasattr(self, "workers"):
return
self._stop_all_workers()
self.terminate()
del self.index_list
del self.idmap
del self.gid
del self.gidc
del self.num_idle
del self.num_idle_c
del self.workers
del self.index_list_numpy
def __del__(self):
if mp_log_v:
print("dataset deleted")

View File

@ -1,8 +1,9 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Meng-Hao Guo <guomenghao1997@gmail.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

View File

@ -1,9 +1,10 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Guoye Yang <498731903@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -308,7 +309,7 @@ class DepthwiseConv(Function):
grid = dim3(ksize_width, ksize_height, output_channels);
threads = dim3(std::min(output_width, block_size), crop_output_height, 1);
cudaMemsetAsync(filter_grad_p, 0, filter_grad->size);
KernelDepthwiseConvFilterGrad<
input_type><<<grid, threads, 0>>>(

View File

@ -1,8 +1,9 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

View File

@ -1,8 +1,9 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

View File

@ -1,6 +1,6 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Authors:
# Maintainers:
# Haoyang Peng <2247838039@qq.com>
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.

View File

@ -1,9 +1,10 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

View File

@ -1,9 +1,10 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Dun Liang <randonlang@gmail.com>.
# Wenyang Zhou <576825820@qq.com>
#
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -96,6 +97,27 @@ def repeat(x, *shape):
jt.Var.repeat = repeat
def repeat_interleave(x,repeats,dim=None):
# TODO repeats is jt.Var
assert isinstance(repeats,int)
if dim == None:
x = x.reshape(-1)
dim=0
if dim<0: dim+=x.ndim
tar_shape = list(x.shape)
x_shape = list(x.shape)
tar_shape[dim] = tar_shape[dim]*repeats
dims = []
for i in range(len(tar_shape)):
if dim==i:
dims.append(f"i{i}/{repeats}")
else:
dims.append(f"i{i}")
return x.reindex(tar_shape,dims)
jt.Var.repeat_interleave = repeat_interleave
def chunk(x, chunks, dim=0):
r'''
Splits a var into a specific number of chunks. Each chunk is a view of the input var.
@ -209,15 +231,18 @@ def flip(x, dim=0):
>>> x.flip(1)
[[4 3 2 1]]
'''
assert isinstance(dim, int)
if dim<0:
dim+=x.ndim
assert dim>=0 and dim<len(x.shape)
if isinstance(dim, int):
dim = [dim]
for i in range(len(dim)):
if dim[i]<0:
dim[i] += x.ndim
assert dim[i]>=0 and dim[i]<x.ndim
dim = set(dim)
tar_dims = []
for i in range(len(x.shape)):
if i == dim:
tar_dims.append(f"{x.shape[dim]-1}-i{i}")
if i in dim:
tar_dims.append(f"xshape{i}-1-i{i}")
else:
tar_dims.append(f"i{i}")
return x.reindex(x.shape, tar_dims)
@ -335,16 +360,37 @@ def unbind(x, dim=0):
jt.Var.unbind = unbind
def make_grid(x, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0):
assert range == None
assert isinstance(range, tuple) or range is None
assert scale_each == False
if isinstance(x, list): x = jt.stack(x)
if normalize: x = (x - x.min()) / (x.max() - x.min())
if normalize:
if range is None: x = (x - x.min()) / (x.max() - x.min())
else: x = (x - range[0]) / (range[1] - range[0])
b,c,h,w = x.shape
ncol = math.ceil(b / nrow)
return x.reindex([c, h*ncol+(ncol+1)*padding, w*nrow+(nrow+1)*padding],
[f"i1/{padding+h}*{nrow}+i2/{padding+w}", "i0",
f"i1-i1/{padding+h}*{padding+h}-{padding}", f"i2-i2/{padding+w}*{padding+w}-{padding}"], overflow_value=pad_value)
def save_image(
x,
filepath,
nrow: int = 8,
padding: int = 2,
normalize: bool = False,
range = None,
scale_each = False,
pad_value = 0,
format = None
):
from PIL import Image
grid = make_grid(x, nrow=nrow, padding=padding, pad_value=pad_value,
normalize=normalize, range=range, scale_each=scale_each)
ndarr = (grid*255+0.5).clamp(0, 255).permute(1, 2, 0).uint8().numpy()
im = Image.fromarray(ndarr)
im.save(filepath, format=format)
def _ntuple(n):
def parse(x):
@ -582,12 +628,11 @@ def gather(x,dim,index):
return x.reindex(ins)
jt.Var.gather = gather
def prod(x,dim=0):
def _prod(x,dim=0):
x = jt.log(x)
x = x.sum(dim=dim)
return jt.exp(x)
jt.Var.prod = prod
def cumsum_forward(np, data):
a = data['inputs'][0]

View File

@ -1,8 +1,9 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -61,6 +62,7 @@ class AlexNet(nn.Module):
x = self.classifier(x)
return x
def alexnet(**kwargs):
def alexnet(pretrained=False, **kwargs):
model = AlexNet(**kwargs)
if pretrained: model.load("jittorhub://alexnet.pkl")
return model

View File

@ -1,8 +1,9 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -21,7 +22,7 @@ def densenet121(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
'''
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
assert not pretrained, "pretrained doesn't support now"
if pretrained: model.load("jittorhub://densenet121.pkl")
return model
def densenet161(pretrained=False, **kwargs):
@ -32,7 +33,7 @@ def densenet161(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
'''
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs)
assert not pretrained, "pretrained doesn't support now"
if pretrained: model.load("jittorhub://densenet161.pkl")
return model
def densenet169(pretrained=False, **kwargs):
@ -43,7 +44,7 @@ def densenet169(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
'''
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
assert not pretrained, "pretrained doesn't support now"
if pretrained: model.load("jittorhub://densenet169.pkl")
return model
def densenet201(pretrained=False, **kwargs):
@ -54,7 +55,7 @@ def densenet201(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
'''
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
assert not pretrained, "pretrained doesn't support now"
if pretrained: model.load("jittorhub://densenet201.pkl")
return model

View File

@ -1,8 +1,9 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -12,8 +13,10 @@ from jittor import nn
__all__ = ['GoogLeNet', 'googlenet']
def googlenet(**kwargs):
return GoogLeNet(**kwargs)
def googlenet(pretrained=False, **kwargs):
model = GoogLeNet(**kwargs)
if pretrained: model.load("jittorhub://googlenet.pkl")
return model
class GoogLeNet(nn.Module):
""" GoogLeNet model architecture.

View File

@ -4,7 +4,9 @@ from jittor import nn
__all__ = ['Inception3', 'inception_v3']
def inception_v3(pretrained=False, progress=True, **kwargs):
return Inception3(**kwargs)
model = Inception3(**kwargs)
if pretrained: model.load("jittorhub://inception_v3.pkl")
return model
class Inception3(nn.Module):
""" Inceptionv3 model architecture.

View File

@ -1,8 +1,9 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -90,18 +91,22 @@ class MNASNet(nn.Module):
x = x.mean([2, 3])
return self.classifier(x)
def mnasnet0_5(**kwargs):
def mnasnet0_5(pretrained=False, **kwargs):
model = MNASNet(0.5, **kwargs)
if pretrained: model.load("jittorhub://mnasnet0_5.pkl")
return model
def mnasnet0_75(**kwargs):
def mnasnet0_75(pretrained=False, **kwargs):
model = MNASNet(0.75, **kwargs)
if pretrained: model.load("jittorhub://mnasnet0_75.pkl")
return model
def mnasnet1_0(**kwargs):
def mnasnet1_0(pretrained=False, **kwargs):
model = MNASNet(1.0, **kwargs)
if pretrained: model.load("jittorhub://mnasnet1_0.pkl")
return model
def mnasnet1_3(**kwargs):
def mnasnet1_3(pretrained=False, **kwargs):
model = MNASNet(1.3, **kwargs)
if pretrained: model.load("jittorhub://mnasnet1_3.pkl")
return model

View File

@ -1,8 +1,9 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -93,7 +94,8 @@ class MobileNetV2(nn.Module):
def execute(self, x):
return self._forward_impl(x)
def mobilenet_v2():
def mobilenet_v2(pretrained=False):
model = MobileNetV2()
if pretrained: model.load("jittorhub://mobilenet_v2.pkl")
return model

View File

@ -1,9 +1,10 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Guowei Yang <471184555@qq.com>
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -154,16 +155,23 @@ def _resnet(block, layers, **kwargs):
model = ResNet(block, layers, **kwargs)
return model
def Resnet18(**kwargs):
return _resnet(BasicBlock, [2, 2, 2, 2], **kwargs)
def Resnet18(pretrained=False, **kwargs):
model = _resnet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained: model.load("jittorhub://resnet18.pkl")
return model
resnet18 = Resnet18
def Resnet34(**kwargs):
return _resnet( BasicBlock, [3, 4, 6, 3], **kwargs)
def Resnet34(pretrained=False, **kwargs):
model = _resnet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained: model.load("jittorhub://resnet34.pkl")
return model
resnet34 = Resnet34
def Resnet50(**kwargs):
return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
def Resnet50(pretrained=False, **kwargs):
model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained: model.load("jittorhub://resnet50.pkl")
return model
resnet50 = Resnet50
def Resnet38(**kwargs):
@ -174,7 +182,7 @@ def Resnet26(**kwargs):
return _resnet(Bottleneck, [1, 2, 4, 1], **kwargs)
resnet26 = Resnet26
def Resnet101(**kwargs):
def Resnet101(pretrained=False, **kwargs):
"""
ResNet-101 model architecture.
@ -188,28 +196,38 @@ def Resnet101(**kwargs):
return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
resnet101 = Resnet101
def Resnet152(**kwargs):
return _resnet(Bottleneck, [3, 8, 36, 3], **kwargs)
def Resnet152(pretrained=False, **kwargs):
model = _resnet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained: model.load("jittorhub://resnet152.pkl")
return model
resnet152 = Resnet152
def Resnext50_32x4d(**kwargs):
def Resnext50_32x4d(pretrained=False, **kwargs):
kwargs['groups'] = 32
kwargs['width_per_group'] = 4
return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained: model.load("jittorhub://resnext50_32x4d.pkl")
return model
resnext50_32x4d = Resnext50_32x4d
def Resnext101_32x8d(**kwargs):
def Resnext101_32x8d(pretrained=False, **kwargs):
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
model = _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained: model.load("jittorhub://resnext101_32x8d.pkl")
return model
resnext101_32x8d = Resnext101_32x8d
def Wide_resnet50_2(**kwargs):
def Wide_resnet50_2(pretrained=False, **kwargs):
kwargs['width_per_group'] = (64 * 2)
return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained: model.load("jittorhub://wide_resnet50_2.pkl")
return model
wide_resnet50_2 = Wide_resnet50_2
def Wide_resnet101_2(**kwargs):
def Wide_resnet101_2(pretrained=False, **kwargs):
kwargs['width_per_group'] = (64 * 2)
return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
model = _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained: model.load("jittorhub://wide_resnet101_2.pkl")
return model
wide_resnet101_2 = Wide_resnet101_2

View File

@ -1,9 +1,10 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -93,14 +94,22 @@ def _shufflenetv2(arch, *args):
model = ShuffleNetV2(*args)
return model
def shufflenet_v2_x0_5():
return _shufflenetv2('shufflenetv2_x0.5', [4, 8, 4], [24, 48, 96, 192, 1024])
def shufflenet_v2_x0_5(pretrained=False):
model = _shufflenetv2('shufflenetv2_x0.5', [4, 8, 4], [24, 48, 96, 192, 1024])
if pretrained: model.load("jittorhub://shufflenet_v2_x0_5.pkl")
return model
def shufflenet_v2_x1_0():
return _shufflenetv2('shufflenetv2_x1.0', [4, 8, 4], [24, 116, 232, 464, 1024])
def shufflenet_v2_x1_0(pretrained=False):
model = _shufflenetv2('shufflenetv2_x1.0', [4, 8, 4], [24, 116, 232, 464, 1024])
if pretrained: model.load("jittorhub://shufflenet_v2_x1_0.pkl")
return model
def shufflenet_v2_x1_5():
return _shufflenetv2('shufflenetv2_x1.5', [4, 8, 4], [24, 176, 352, 704, 1024])
def shufflenet_v2_x1_5(pretrained=False):
model = _shufflenetv2('shufflenetv2_x1.5', [4, 8, 4], [24, 176, 352, 704, 1024])
if pretrained: model.load("jittorhub://shufflenet_v2_x1_5.pkl")
return model
def shufflenet_v2_x2_0():
return _shufflenetv2('shufflenetv2_x2.0', [4, 8, 4], [24, 244, 488, 976, 2048])
def shufflenet_v2_x2_0(pretrained=False):
model = _shufflenetv2('shufflenetv2_x2.0', [4, 8, 4], [24, 244, 488, 976, 2048])
if pretrained: model.load("jittorhub://shufflenet_v2_x2_0.pkl")
return model

View File

@ -1,8 +1,9 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -83,8 +84,12 @@ def _squeezenet(version, **kwargs):
model = SqueezeNet(version, **kwargs)
return model
def squeezenet1_0(**kwargs):
return _squeezenet('1_0', **kwargs)
def squeezenet1_0(pretrained=False, **kwargs):
model = _squeezenet('1_0', **kwargs)
if pretrained: model.load("jittorhub://squeezenet1_0.pkl")
return model
def squeezenet1_1(**kwargs):
return _squeezenet('1_1', **kwargs)
def squeezenet1_1(pretrained=False, **kwargs):
model = _squeezenet('1_1', **kwargs)
if pretrained: model.load("jittorhub://squeezenet1_1.pkl")
return model

View File

@ -1,8 +1,9 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Guoye Yang <498731903@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -67,33 +68,49 @@ def _vgg(arch, cfg, batch_norm, **kwargs):
return model
def vgg11(**kwargs):
return _vgg('vgg11', 'A', False, **kwargs)
def vgg11(pretrained=False, **kwargs):
model = _vgg('vgg11', 'A', False, **kwargs)
if pretrained: model.load("jittorhub://vgg11.pkl")
return model
def vgg11_bn(**kwargs):
return _vgg('vgg11_bn', 'A', True, **kwargs)
def vgg11_bn(pretrained=False, **kwargs):
model = _vgg('vgg11_bn', 'A', True, **kwargs)
if pretrained: model.load("jittorhub://vgg11_bn.pkl")
return model
def vgg13(**kwargs):
return _vgg('vgg13', 'B', False, **kwargs)
def vgg13(pretrained=False, **kwargs):
model = _vgg('vgg13', 'B', False, **kwargs)
if pretrained: model.load("jittorhub://vgg13.pkl")
return model
def vgg13_bn(**kwargs):
return _vgg('vgg13_bn', 'B', True, **kwargs)
def vgg13_bn(pretrained=False, **kwargs):
model = _vgg('vgg13_bn', 'B', True, **kwargs)
if pretrained: model.load("jittorhub://vgg13_bn.pkl")
return model
def vgg16(**kwargs):
return _vgg('vgg16', 'D', False, **kwargs)
def vgg16(pretrained=False, **kwargs):
model = _vgg('vgg16', 'D', False, **kwargs)
if pretrained: model.load("jittorhub://vgg16.pkl")
return model
def vgg16_bn(**kwargs):
return _vgg('vgg16_bn', 'D', True, **kwargs)
def vgg16_bn(pretrained=False, **kwargs):
model = _vgg('vgg16_bn', 'D', True, **kwargs)
if pretrained: model.load("jittorhub://vgg16_bn.pkl")
return model
def vgg19(**kwargs):
return _vgg('vgg19', 'E', False, **kwargs)
def vgg19(pretrained=False, **kwargs):
model = _vgg('vgg19', 'E', False, **kwargs)
if pretrained: model.load("jittorhub://vgg19.pkl")
return model
def vgg19_bn(**kwargs):
return _vgg('vgg19_bn', 'E', True, **kwargs)
def vgg19_bn(pretrained=False, **kwargs):
model = _vgg('vgg19_bn', 'E', True, **kwargs)
if pretrained: model.load("jittorhub://vgg19_bn.pkl")
return model

View File

@ -1,12 +1,13 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Guowei Yang <471184555@qq.com>
# Guoye Yang <498731903@qq.com>
# Wenyang Zhou <576825820@qq.com>
# Meng-Hao Guo <guomenghao1997@gmail.com>
# Dun Liang <randonlang@gmail.com>.
#
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -16,7 +17,7 @@ import numpy as np
import collections
import math
from collections import OrderedDict
from jittor.pool import Pool, pool, AdaptiveAvgPool2d
from jittor.pool import *
from jittor.optim import *
from jittor.misc import _pair
@ -154,6 +155,7 @@ def get_init_var_rand(shape, dtype):
def relu(x): return jt.ternary((x>0.0), x, jt.broadcast_var(0.0, x))
def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale)
def relu6(x): return jt.minimum(jt.maximum(x, 0.0), 6.0)
def elu(x,alpha=1.0):return jt.ternary(x>0,x,alpha*(x.exp()-1))
def sign(x):
one = jt.ones(x.shape)
x = jt.ternary(x>0, one, x)
@ -165,6 +167,13 @@ def gelu(x):
r = erf*x*.5
return r
class ELU(Module):
def __init__(self,alpha=1.0):
self.alpha=alpha
def execute(self,x):
return elu(x,self.alpha)
class PReLU(Module):
def __init__(self, num_parameters=1, init_=0.25):
self.num_parameters = num_parameters
@ -238,6 +247,30 @@ def smooth_l1_loss(y_true, y_pred,reduction="mean"):
else:
raise ValueError(f'not support {reduction}')
def nll_loss(output,target,weight=None,ignore_index=-100,reduction='mean'):
assert output.ndim<=2 and output.ndim>0 and target.ndim==1
n_classes = output.shape[-1]
assert weight is None or weight.numel()==n_classes
assert ignore_index<0 or ignore_index<n_classes
if weight is None:
weight = jt.ones((n_classes,))
if ignore_index>0:
weight[ignore_index]=0
if output.ndim==2:
index = jt.index((output.shape[0],),dim=0)
loss = -output[index,target]*weight[target]
else:
loss = -output[target[0]]*weight[target[0]]
if reduction=="mean":
total_weight = weight[target].sum() if output.ndim==2 else weight[target[0]].sum()
return loss.sum()/total_weight
elif reduction=="sum":
return loss.sum()
elif reduction=="none":
return loss
else:
raise ValueError(f'not support {reduction}')
class CrossEntropyLoss(Module):
def __init__(self,ignore_index=None):
self.ignore_index = ignore_index
@ -330,6 +363,9 @@ class Dropout(Module):
output = output * noise / (1.0 - self.p) # div keep prob
return output
def dropout(x,p=0.5,is_train=False):
return Dropout(p=p,is_train=is_train)(x)
class Linear(Module):
def __init__(self, in_features, out_features, bias=True):
self.in_features = in_features
@ -707,6 +743,45 @@ class ConvTranspose(Module):
y = y + b
return y
def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
x = input
N,C,H,W = x.shape
i,o,h,w = weight.shape
assert C==i
assert groups==1, "Group conv not supported yet."
stride = stride if isinstance(stride, tuple) else (stride, stride)
dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
# added
padding = padding if isinstance(padding, tuple) else (padding, padding)
output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding)
assert output_padding[0] < max(stride[0], dilation[0]) and \
output_padding[1] < max(stride[1], dilation[1]), \
"output padding must be smaller than max(stride, dilation)"
stride_h, stride_w = stride
padding_h, padding_w = padding
dilation_h, dilation_w = dilation
h_out = (H-1) * stride_h + output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
w_out = (W-1) * stride_w + output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
out_shape = (N, o, h_out, w_out)
shape = (N, i, o, H, W, h, w)
xx = x.broadcast(shape, (2, 5, 6)) # i,h,w
ww = weight.broadcast(shape, (0, 3, 4)) # N,H,W
y = (ww*xx).reindex_reduce("add", out_shape, [
'i0', # N
'i2', # o
f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid
f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid
])
if isinstance(bias, jt.Var):
b = bias.broadcast(y.shape, [0,2,3])
y = y + b
else:
assert not bias, "Bias should be none or jittor var"
return y
conv_transpose2d = conv_transpose
def pad(x,padding, mode='constant', value=0):
assert mode in ['constant','replicate','reflect','circular'],'only support constant,replicate,reflect,circular pad'
@ -865,8 +940,9 @@ class Sigmoid(Module):
def execute(self, x) :
return x.sigmoid()
def softplus(x,beta=1,threshold=20):
return 1 / beta * jt.log(1 + (beta * x).exp())
def softplus(x,beta=1.0,threshold=20.0):
return 1 / beta * jt.log(1 + (beta * x).minimum(threshold).exp()) + \
(x - threshold/beta).maximum(0.0)
def hardtanh(x,min_val=-1,max_val=1):
return jt.clamp(x,min_v=min_val,max_v=max_val)
@ -887,7 +963,7 @@ class Softplus(Module):
self.threshold = threshold
def execute(self, x):
return 1 / self.beta * jt.log(1 + (self.beta * x).exp())
return softplus(x, self.beta, self.threshold)
class Resize(Module):
def __init__(self, size, mode="nearest", align_corners=False):

View File

@ -1,12 +1,13 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Guowei Yang <471184555@qq.com>
# Guoye Yang <498731903@qq.com>
# Wenyang Zhou <576825820@qq.com>
# Meng-Hao Guo <guomenghao1997@gmail.com>
# Dun Liang <randonlang@gmail.com>.
#
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -33,11 +34,15 @@ class Optimizer(object):
assert isinstance(pg, dict)
self.param_groups.append(pg)
self.n_step = 0
def add_param_group(self, group):
self.param_groups.append(group)
@property
def defaults(self):
exclude = set(("defaults", "param_groups", "n_step"))
return { k:v for k, v in self.__dict__.items() if k[0] != '_' and k not in exclude }
exclude = set(("defaults", "param_groups", "n_step", "pre_step", "step"))
return { k:v for k, v in self.__dict__.items()
if k[0] != '_' and k not in exclude and not callable(v) }
def pre_step(self, loss):
""" something should be done before step, such as calc gradients, mpi sync, and so on.
@ -115,6 +120,12 @@ class SGD(Optimizer):
for p in pg["params"]:
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
def add_param_group(self, group):
values = group["values"] = []
for p in group["params"]:
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
self.param_groups.append(group)
def step(self, loss=None):
if loss is not None:
self.pre_step(loss)
@ -159,6 +170,12 @@ class RMSprop(Optimizer):
for p in pg["params"]:
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
def add_param_group(self, group):
values = group["values"] = []
for p in group["params"]:
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
self.param_groups.append(group)
def step(self, loss=None):
if loss is not None:
self.pre_step(loss)
@ -195,6 +212,14 @@ class Adam(Optimizer):
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
m.append(jt.zeros(p.shape, p.dtype).stop_grad())
def add_param_group(self, group):
values = group["values"] = []
m = group["m"] = []
for p in group["params"]:
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
m.append(jt.zeros(p.shape, p.dtype).stop_grad())
self.param_groups.append(group)
def step(self, loss=None):
if loss is not None:
self.pre_step(loss)

View File

@ -1,11 +1,12 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Guowei Yang <471184555@qq.com>
# Wenyang Zhou <576825820@qq.com>
# Meng-Hao Guo <guomenghao1997@gmail.com>
# Dun Liang <randonlang@gmail.com>.
#
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
@ -30,11 +31,14 @@ class Pool(Module):
if self.ceil_mode == False:
h = (H+self.padding*2-self.kernel_size)//self.stride+1
w = (W+self.padding*2-self.kernel_size)//self.stride+1
use_code_op = self.op in ['maximum', 'minimum']
# some second order avg_pool is require, so we don't use code op here
else:
h = (H+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
w = (W+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
use_code_op = self.op in ['maximum', 'minimum', 'mean']
if self.op in ['maximum', 'minimum', 'mean']:
if use_code_op:
if self.op == 'mean':
if self.count_include_pad:
count = f"int count = {self.kernel_size*self.kernel_size};"
@ -187,5 +191,25 @@ class AdaptiveAvgPool2d(Module):
])
return xx.reduce("mean", [4,5])
def pool(x, kernel_size, op, padding=0, stride = 1):
return Pool(kernel_size, stride, padding, op=op)(x)
def pool(x, kernel_size, op, padding=0, stride=None):
return Pool(kernel_size, stride, padding, op=op)(x)
class AvgPool2d(Module):
def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
self.layer = Pool(kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode, count_include_pad=count_include_pad, op="mean")
def execute(self, x):
return self.layer(x)
def avg_pool2d(x, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
return AvgPool2d(kernel_size, stride, padding, ceil_mode, count_include_pad)(x)
class MaxPool2d(Module):
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False):
self.layer = Pool(kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, return_indices=return_indices, ceil_mode=ceil_mode, op="maximum")
def execute(self, x):
return self.layer(x)
def max_pool2d(x, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False):
return MaxPool2d(kernel_size, stride, padding, dilation, return_indices, ceil_mode)(x)

View File

@ -1,5 +1,6 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

54
python/jittor/sparse.py Normal file
View File

@ -0,0 +1,54 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Dun Liang <randonlang@gmail.com>.
# Xiangli Li <190569238@qq.com>
#
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import numpy as np
class SparseVar:
def __init__(self,indices,values,shape):
assert isinstance(indices,jt.Var) and isinstance(values,jt.Var) and isinstance(shape,jt.NanoVector)
self.indices = indices
self.values = values
self.shape = shape
self.ndim = len(shape)
def _indices(self):
return self.indices
def _values(self):
return self.values
def t(self):
indices = list(self.indices.split(1,dim=0))
indices[-1],indices[-2] = indices[-2],indices[-1]
indices = jt.contrib.concat(indices,dim=0)
shape = list(self.shape)
shape[-1],shape[-2] = shape[-2],shape[-1]
shape = jt.NanoVector(shape)
return SparseVar(indices,self.values,shape)
def to_dense(self):
ret = jt.zeros(self.shape,self.values.dtype)
indices = tuple(self.indices.split(1,dim=0))
ret[indices]=self.values
return ret
def sparse_array(indices,values,shape):
return SparseVar(indices,values,shape)
def spmm(spase_x,y):
assert isinstance(spase_x,SparseVar) and isinstance(y,jt.Var)
assert spase_x.ndim==2 and y.ndim==2 and spase_x.shape[-1]==y.shape[0]
# TODO
x = spase_x.to_dense()
return jt.matmul(x,y)

View File

@ -1,5 +1,6 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

View File

@ -28,7 +28,7 @@ RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
RUN pip3 uninstall jittor -y
COPY . jittor
RUN python3.7 -m pip install jittor
RUN python3.7 -m pip install ./jittor
RUN python3.7 -m jittor.test.test_core
EOF

View File

@ -28,7 +28,7 @@ RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
RUN pip3 uninstall jittor -y
COPY . jittor
RUN python3.7 -m pip install jittor
RUN python3.7 -m pip install ./jittor
RUN python3.7 -m jittor.test.test_core
EOF

View File

@ -28,7 +28,7 @@ RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
RUN pip3 uninstall jittor -y
COPY . jittor
RUN python3.7 -m pip install jittor
RUN python3.7 -m pip install ./jittor
RUN python3.7 -m jittor.test.test_core
EOF

View File

@ -28,7 +28,7 @@ RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
RUN pip3 uninstall jittor -y
COPY . jittor
RUN python3.7 -m pip install jittor
RUN python3.7 -m pip install ./jittor
RUN python3.7 -m jittor.test.test_core
EOF

View File

@ -26,7 +26,7 @@ RUN pip3 install jittor --timeout 100 && python3 -m jittor.test.test_example
RUN pip3 uninstall jittor -y
COPY . jittor
RUN python3 -m pip install jittor
RUN python3 -m pip install ./jittor
RUN python3 -m jittor.test.test_core
EOF

View File

@ -28,7 +28,7 @@ RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
RUN pip3 uninstall jittor -y
COPY . jittor
RUN python3.7 -m pip install jittor
RUN python3.7 -m pip install ./jittor
RUN python3.7 -m jittor.test.test_core
EOF

View File

@ -1,5 +1,6 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

View File

@ -1,5 +1,6 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

View File

@ -1,5 +1,6 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

View File

@ -1,8 +1,9 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Guoye Yang <498731903@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

View File

@ -5,7 +5,9 @@
# ***************************************************************
import unittest
import jittor as jt
from jittor.nn import Pool, pool
from jittor.nn import Pool, pool, AvgPool2d, avg_pool2d
from jittor.nn import MaxPool2d as j_MaxPool2d
from jittor.nn import max_pool2d as j_max_pool2d
import numpy as np
from .test_core import expect_error
from .test_grad import ngrad
@ -101,7 +103,7 @@ class TestArgPoolOp(unittest.TestCase):
check(jt_model, torch_model, shape, False)
for i in range(10):
check(jt_model, torch_model, [1,1,300,300], True)
def test_cpu_(self):
# x = jt.random([32, 128, 157, 300])
x = jt.random([4, 128, 157, 300])
@ -138,5 +140,50 @@ class TestArgPoolOp(unittest.TestCase):
shape = (2, 16, 33, 33)
check(jt_model, torch_model, shape, False)
def test_AvgPool2d(self):
from torch.nn import AvgPool2d as t_AvgPool2d
jt_model = AvgPool2d(3, 1, 1, ceil_mode=True)
torch_model = t_AvgPool2d(3, 1, 1, ceil_mode=True)
shape = (2, 16, 33, 33)
check(jt_model, torch_model, shape, False)
jt_model = AvgPool2d(3, 1, 1, ceil_mode=True, count_include_pad=False)
torch_model = t_AvgPool2d(3, 1, 1, ceil_mode=True, count_include_pad=False)
shape = (2, 16, 100, 100)
check(jt_model, torch_model, shape, False)
print('finish')
def test_avg_pool2d(self):
from torch.nn.functional import avg_pool2d as t_avg_pool2d
arr = np.random.random((2, 16, 33, 33))
jt_model = avg_pool2d(jt.array(arr), 3, 1, 1, ceil_mode=True)
torch_model = t_avg_pool2d(torch.Tensor(arr), 3, 1, 1, ceil_mode=True)
assert np.allclose(jt_model.numpy(), torch_model.numpy())
jt_model = avg_pool2d(jt.array(arr), 3, 1, 1, ceil_mode=True, count_include_pad=False)
torch_model = t_avg_pool2d(torch.Tensor(arr), 3, 1, 1, ceil_mode=True, count_include_pad=False)
assert np.allclose(jt_model.numpy(), torch_model.numpy())
print('finish')
def test_MaxPool2d(self):
from torch.nn import MaxPool2d
jt_model = j_MaxPool2d(3, 1, 1, ceil_mode=True)
torch_model = MaxPool2d(3, 1, 1, ceil_mode=True)
shape = (2, 16, 33, 33)
check(jt_model, torch_model, shape, False)
print('finish')
def test_max_pool2d(self):
from torch.nn.functional import max_pool2d
arr = np.random.random((2, 16, 33, 33))
jt_model = j_max_pool2d(jt.array(arr), 3, 1, 1, ceil_mode=True)
torch_model = max_pool2d(torch.Tensor(arr), 3, 1, 1, ceil_mode=True)
assert np.allclose(jt_model.numpy(), torch_model.numpy())
jt_model = j_max_pool2d(jt.array(arr), 3, 1, 1)
torch_model = max_pool2d(torch.Tensor(arr), 3, 1, 1)
assert np.allclose(jt_model.numpy(), torch_model.numpy())
print('finish')
if __name__ == "__main__":
unittest.main()

View File

@ -1,8 +1,9 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Guoye Yang <498731903@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

View File

@ -1,8 +1,9 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Guoye Yang <498731903@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

View File

@ -1,5 +1,6 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

View File

@ -1,8 +1,9 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

View File

@ -1,8 +1,9 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

View File

@ -1,9 +1,10 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

View File

@ -1,6 +1,6 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Authors: Dun Liang <randonlang@gmail.com>.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

View File

@ -1,9 +1,10 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

View File

@ -1,5 +1,6 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

View File

@ -1,9 +1,10 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Copyright (c) 2020 Jittor. All Rights Reserved.
# Maintainers:
# Meng-Hao Guo <guomenghao1997@gmail.com>
# Dun Liang <randonlang@gmail.com>.
#
# All Rights Reserved.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

Some files were not shown because too many files have changed in this diff Show More