mirror of https://github.com/Jittor/Jittor
merge
This commit is contained in:
commit
1be344526c
31
README.cn.md
31
README.cn.md
|
@ -119,7 +119,9 @@ Jittor 一共提供三种方式安装: pip安装, 一键脚本安装 和 手动
|
|||
|
||||
```bash
|
||||
sudo apt install python3.7-dev libomp-dev
|
||||
sudo python3.7 -m pip install git+https://github.com/Jittor/jittor.git
|
||||
python3.7 -m pip install jittor
|
||||
# or install from github(latest version)
|
||||
# python3.7 -m pip install git+https://github.com/Jittor/jittor.git
|
||||
python3.7 -m jittor.test.test_example
|
||||
```
|
||||
|
||||
|
@ -217,13 +219,13 @@ jt.flags.use_cuda = 1
|
|||
```
|
||||
|
||||
|
||||
### 可选步骤五:进行完整测试
|
||||
### 可选步骤五:测试训练Resnet18
|
||||
|
||||
|
||||
要检查Jittor的完整性,您可以运行完整的测试。
|
||||
要检查Jittor的完整性,您可以运行Resnet18训练测试。
|
||||
|
||||
```bash
|
||||
python3.7 -m jittor.test -v
|
||||
python3.7 -m jittor.test.test_resnet
|
||||
```
|
||||
|
||||
如果这些测试失败,请为我们报告错误,我们十分欢迎您为Jittor做出贡献^ _ ^
|
||||
|
@ -360,10 +362,29 @@ Jittor还很年轻。 它可能存在错误和问题。 请在我们的错误跟
|
|||
|
||||
|
||||
|
||||
QQ 群:761222083
|
||||
|
||||
|
||||
|
||||
## 团队
|
||||
|
||||
|
||||
Jittor目前由来自[清华大学计算机图形学组](https://cg.cs.tsinghua.edu.cn/)的梁盾,杨国烨,杨国炜,周文洋和国孟昊等博士生维护。 如果您也对Jittor感兴趣并希望对其进行改进,请加入我们!
|
||||
Jittor目前由[清华大学计算机图形学组](https://cg.cs.tsinghua.edu.cn/)维护。 如果您也对Jittor感兴趣并希望对其进行改进,请加入我们!
|
||||
|
||||
|
||||
## 引用
|
||||
|
||||
```
|
||||
@article{hu2020jittor,
|
||||
title={Jittor: a novel deep learning framework with meta-operators and unified graph execution},
|
||||
author={Hu, Shi-Min and Liang, Dun and Yang, Guo-Ye and Yang, Guo-Wei and Zhou, Wen-Yang},
|
||||
journal={Information Sciences},
|
||||
volume={63},
|
||||
number={222103},
|
||||
pages={1--222103},
|
||||
year={2020}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## 版权声明
|
||||
|
|
32
README.md
32
README.md
|
@ -116,7 +116,9 @@ Jittor offers three ways to install: pip, script or manual.
|
|||
|
||||
```bash
|
||||
sudo apt install python3.7-dev libomp-dev
|
||||
sudo python3.7 -m pip install git+https://github.com/Jittor/jittor.git
|
||||
python3.7 -m pip install jittor
|
||||
# or install from github(latest version)
|
||||
# python3.7 -m pip install git+https://github.com/Jittor/jittor.git
|
||||
python3.7 -m jittor.test.test_example
|
||||
```
|
||||
|
||||
|
@ -210,14 +212,14 @@ import jittor as jt
|
|||
jt.flags.use_cuda = 1
|
||||
```
|
||||
|
||||
### Optional Step 5: Run full tests
|
||||
### Optional Step 5: Test Resnet18 training
|
||||
|
||||
|
||||
To check the integrity of Jittor, you can run full tests.
|
||||
To check the integrity of Jittor, you can run Resnet18 training test.
|
||||
|
||||
|
||||
```bash
|
||||
python3.7 -m jittor.test -v
|
||||
python3.7 -m jittor.test.test_resnet
|
||||
```
|
||||
if those tests are failed, please report bugs for us, and feel free to contribute ^_^
|
||||
|
||||
|
@ -353,12 +355,32 @@ Email: jittor@qq.com
|
|||
|
||||
File an issue: https://github.com/Jittor/jittor/issues
|
||||
|
||||
QQ Group: 761222083
|
||||
|
||||
|
||||
<img src="https://cg.cs.tsinghua.edu.cn/jittor/images/news/2020-12-8-21-19-1_2_2/fig4.png" width="200"/>
|
||||
|
||||
## The Team
|
||||
|
||||
|
||||
Jittor is currently maintained by Dun Liang, Guo-Ye Yang, Guo-Wei Yang, Wen-Yang Zhou and Meng-Hao Guo etc. from the [Tsinghua CSCG Group](https://cg.cs.tsinghua.edu.cn/). If you are also interested in Jittor and want to improve it, Please join us!
|
||||
Jittor is currently maintained by the [Tsinghua CSCG Group](https://cg.cs.tsinghua.edu.cn/). If you are also interested in Jittor and want to improve it, Please join us!
|
||||
|
||||
|
||||
## Citation
|
||||
|
||||
|
||||
```
|
||||
@article{hu2020jittor,
|
||||
title={Jittor: a novel deep learning framework with meta-operators and unified graph execution},
|
||||
author={Hu, Shi-Min and Liang, Dun and Yang, Guo-Ye and Yang, Guo-Wei and Zhou, Wen-Yang},
|
||||
journal={Information Sciences},
|
||||
volume={63},
|
||||
number={222103},
|
||||
pages={1--222103},
|
||||
year={2020}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
|
||||
|
|
|
@ -151,7 +151,9 @@ Jittor 一共提供三种方式安装: pip安装, 一键脚本安装 和 手动
|
|||
|
||||
```bash
|
||||
sudo apt install python3.7-dev libomp-dev
|
||||
sudo python3.7 -m pip install git+https://github.com/Jittor/jittor.git
|
||||
python3.7 -m pip install jittor
|
||||
# or install from github(latest version)
|
||||
# python3.7 -m pip install git+https://github.com/Jittor/jittor.git
|
||||
python3.7 -m jittor.test.test_example
|
||||
```
|
||||
|
||||
|
@ -453,13 +455,35 @@ Email: jittor@qq.com
|
|||
|
||||
File an issue: https://github.com/Jittor/jittor/issues
|
||||
|
||||
QQ Group: 761222083
|
||||
|
||||
QQ 群:761222083
|
||||
|
||||
<img src="https://cg.cs.tsinghua.edu.cn/jittor/images/news/2020-12-8-21-19-1_2_2/fig4.png" width="200"/>
|
||||
|
||||
## The Team
|
||||
|
||||
## 团队
|
||||
|
||||
Jittor is currently maintained by Dun Liang, Guo-Ye Yang, Guo-Wei Yang, Wen-Yang Zhou and Meng-Hao Guo etc. from the [Tsinghua CSCG Group](https://cg.cs.tsinghua.edu.cn/). If you are also interested in Jittor and want to improve it, Please join us!
|
||||
Jittor is currently maintained by the [Tsinghua CSCG Group](https://cg.cs.tsinghua.edu.cn/). If you are also interested in Jittor and want to improve it, Please join us!
|
||||
|
||||
Jittor目前由来自[清华大学计算机图形学组](https://cg.cs.tsinghua.edu.cn/)的梁盾,杨国烨,杨国炜,周文洋和国孟昊等博士生维护。 如果您也对Jittor感兴趣并希望对其进行改进,请加入我们!
|
||||
Jittor目前由[清华大学计算机图形学组](https://cg.cs.tsinghua.edu.cn/)维护。 如果您也对Jittor感兴趣并希望对其进行改进,请加入我们!
|
||||
|
||||
## Citation
|
||||
|
||||
## 引用
|
||||
|
||||
```
|
||||
@article{hu2020jittor,
|
||||
title={Jittor: a novel deep learning framework with meta-operators and unified graph execution},
|
||||
author={Hu, Shi-Min and Liang, Dun and Yang, Guo-Ye and Yang, Guo-Wei and Zhou, Wen-Yang},
|
||||
journal={Information Sciences},
|
||||
volume={63},
|
||||
number={222103},
|
||||
pages={1--222103},
|
||||
year={2020}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Xiangli Li <1905692338@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Xiangli Li <1905692338@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Meng-Hao Guo <guomenghao1997@gmail.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Meng-Hao Guo <guomenghao1997@gmail.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Dun Liang <randonlang@gmail.com>
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Dun Liang <randonlang@gmail.com>
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Dun Liang <randonlang@gmail.com>
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Dun Liang <randonlang@gmail.com>
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guoye Yang <498731903@qq.com>.
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guoye Yang <498731903@qq.com>.
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor.
|
||||
// Authors:
|
||||
// All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor.
|
||||
// Authors:
|
||||
// All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor.
|
||||
// Authors:
|
||||
// All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
#include <cuda_runtime.h>
|
||||
#include "helper_cuda.h"
|
||||
|
||||
|
||||
#ifdef _CUFFT_H_
|
||||
// cuFFT API errors
|
||||
const char *_cudaGetErrorEnum(cufftResult error) {
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Guoye Yang <498731903@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor.
|
||||
// Authors:
|
||||
// All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor.
|
||||
// Authors:
|
||||
// All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# Meng-Hao Guo <guomenghao1997@gmail.com>
|
||||
#
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.0'
|
||||
__version__ = '1.2.2.12'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -33,9 +34,38 @@ from collections import OrderedDict
|
|||
from collections.abc import Sequence, Mapping
|
||||
import types
|
||||
import pickle
|
||||
import sys
|
||||
import hashlib
|
||||
import sys, os
|
||||
import traceback
|
||||
|
||||
|
||||
def safepickle(obj, path):
|
||||
s = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
|
||||
checksum = hashlib.sha1(s).digest()
|
||||
s += bytes(checksum)
|
||||
s += b"HCAJSLHD"
|
||||
with open(path, 'wb') as f:
|
||||
f.write(s)
|
||||
|
||||
def safeunpickle(path):
|
||||
if path.startswith("jittorhub://"):
|
||||
path = path.replace("jittorhub://", "https://cg.cs.tsinghua.edu.cn/jittor/assets/build/checkpoints/")
|
||||
if path.startswith("https:") or path.startswith("http:"):
|
||||
base = path.split("/")[-1]
|
||||
fname = os.path.join(compiler.ck_path, base)
|
||||
from jittor.utils.misc import download_url_to_local
|
||||
download_url_to_local(path, base, compiler.ck_path, None)
|
||||
path = fname
|
||||
with open(path, "rb") as f:
|
||||
s = f.read()
|
||||
if not s.endswith(b"HCAJSLHD"):
|
||||
return pickle.loads(s)
|
||||
checksum = s[-28:-8]
|
||||
s = s[:-28]
|
||||
if hashlib.sha1(s).digest() != checksum:
|
||||
raise ValueError("Pickle checksum does not match! path: "+path)
|
||||
return pickle.loads(s)
|
||||
|
||||
class _call_no_record_scope:
|
||||
def __enter__(self): pass
|
||||
def __exit__(self, *exc): pass
|
||||
|
@ -307,9 +337,9 @@ def flatten(input, start_dim=0, end_dim=-1):
|
|||
return input.reshape(out_shape)
|
||||
Var.flatten = flatten
|
||||
|
||||
def detach_inplace(x):
|
||||
return x.swap(x.stop_grad().clone())
|
||||
Var.start_grad = Var.detach_inplace = detach_inplace
|
||||
def start_grad(x):
|
||||
return x._update(x)
|
||||
Var.detach_inplace = Var.start_grad = start_grad
|
||||
|
||||
def detach(x):
|
||||
return x.detach()
|
||||
|
@ -436,10 +466,30 @@ def display_memory_info():
|
|||
core.display_memory_info(fileline)
|
||||
|
||||
def load(path):
|
||||
pkl_file = open(path, 'rb')
|
||||
model_dict = pickle.load(pkl_file)
|
||||
if path.endswith(".pth"):
|
||||
try:
|
||||
dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
except:
|
||||
raise RuntimeError("pytorch need to be installed when load pth format.")
|
||||
model_dict = torch.load(path, map_location=torch.device('cpu'))
|
||||
else:
|
||||
model_dict = safeunpickle(path)
|
||||
return model_dict
|
||||
|
||||
def save(params_dict, path):
|
||||
def dfs(x):
|
||||
if isinstance(x, list):
|
||||
for i in range(len(x)):
|
||||
x[i] = dfs(x[i])
|
||||
elif isinstance(x, dict):
|
||||
for k in x:
|
||||
x[k] = dfs(x[k])
|
||||
elif isinstance(x, Var):
|
||||
return x.numpy()
|
||||
return x
|
||||
safepickle(dfs(params_dict), path)
|
||||
|
||||
def _uniq(x):
|
||||
a = set()
|
||||
b = []
|
||||
|
@ -503,6 +553,7 @@ class Module:
|
|||
def callback(parents, k, v, n):
|
||||
stack.append(str(k))
|
||||
for k2, p in v.__dict__.items():
|
||||
if k2.startswith("_"): continue
|
||||
if isinstance(p, Var):
|
||||
ps.append(p)
|
||||
p.name(".".join(stack[1:]+[str(k2)]))
|
||||
|
@ -560,6 +611,21 @@ class Module:
|
|||
return ret
|
||||
self.__class__.__call__ = new_call
|
||||
|
||||
def register_pre_forward_hook(self, func):
|
||||
cls = self.__class__
|
||||
self.__fhook2__ = func
|
||||
if hasattr(cls, "__hooked2__"):
|
||||
return
|
||||
cls.__hooked2__ = True
|
||||
origin_call = cls.__call__
|
||||
def new_call(self, *args, **kw):
|
||||
if hasattr(self, "__fhook2__"):
|
||||
if len(kw):
|
||||
self.__fhook2__(self, args, kw)
|
||||
else:
|
||||
self.__fhook2__(self, args)
|
||||
return origin_call(self, *args, **kw)
|
||||
self.__class__.__call__ = new_call
|
||||
|
||||
def children(self):
|
||||
cd = []
|
||||
|
@ -632,20 +698,10 @@ class Module:
|
|||
params_dict = {}
|
||||
for p in params:
|
||||
params_dict[p.name()] = p.data
|
||||
with open(path, 'wb') as f:
|
||||
pickle.dump(params_dict, f, pickle.HIGHEST_PROTOCOL)
|
||||
safepickle(params_dict, path)
|
||||
|
||||
def load(self, path):
|
||||
if path.endswith(".pth"):
|
||||
try:
|
||||
dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
except:
|
||||
raise RuntimeError("pytorch need to be installed when load pth format.")
|
||||
self.load_parameters(torch.load(path, map_location=torch.device('cpu')))
|
||||
return
|
||||
with open(path, 'rb') as f:
|
||||
self.load_parameters(pickle.load(f))
|
||||
self.load_parameters(load(path))
|
||||
|
||||
def eval(self):
|
||||
def callback(parents, k, v, n):
|
||||
|
@ -790,6 +846,11 @@ can also be None)::
|
|||
def dfs(self, parents, k, callback, callback_leave=None):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def apply(cls, *args, **kw):
|
||||
func = cls()
|
||||
return func(*args, **kw)
|
||||
|
||||
|
||||
def make_module(func, exec_n_args=1):
|
||||
class MakeModule(Module):
|
||||
|
@ -889,7 +950,10 @@ def format(v, spec):
|
|||
return v.item().__format__(spec)
|
||||
Var.__format__ = format
|
||||
|
||||
def get_len(var):
|
||||
return var.shape[0]
|
||||
|
||||
Var.__len__ = get_len
|
||||
int = int32
|
||||
Var.int = Var.int32
|
||||
float = float32
|
||||
|
@ -906,3 +970,27 @@ from . import contrib
|
|||
from . import numpy2cupy
|
||||
from .contrib import concat
|
||||
from .misc import *
|
||||
from . import sparse
|
||||
|
||||
|
||||
def randn(*size, dtype="float32", requires_grad=False):
|
||||
if isinstance(size, tuple) and isinstance(size[0], tuple): size = size[0]
|
||||
arr = jt.random(size, dtype, "normal")
|
||||
if not requires_grad: return arr.stop_grad()
|
||||
return arr
|
||||
|
||||
def rand(*size, dtype="float32", requires_grad=False):
|
||||
if isinstance(size, tuple) and isinstance(size[0], tuple): size = size[0]
|
||||
arr = jt.random(size, dtype)
|
||||
if not requires_grad: return arr.stop_grad()
|
||||
return arr
|
||||
|
||||
def normal(mean, std, size=None, dtype="float32"):
|
||||
if size is None:
|
||||
if isinstance(mean, Var) and isinstance(std, Var):
|
||||
assert mean.shape == std.shape
|
||||
size = mean.shape
|
||||
else:
|
||||
if isinstance(mean, Var): size = mean.shape
|
||||
if isinstance(std, Var): size = std.shape
|
||||
return jt.init.gauss(size, dtype, mean, std)
|
|
@ -1,9 +1,10 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
@ -138,6 +139,10 @@ def setup_cuda_extern():
|
|||
import traceback
|
||||
line = traceback.format_exc()
|
||||
LOG.w(f"CUDA found but {lib_name} is not loaded:\n{line}")
|
||||
if lib_name == "cudnn":
|
||||
LOG.w(f"Develop version of CUDNN not found, "
|
||||
"please refer to CUDA offical tar file installation: "
|
||||
"https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar")
|
||||
|
||||
def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
||||
globals()[lib_name+"_ops"] = None
|
||||
|
@ -183,7 +188,7 @@ def install_cutt(root_folder):
|
|||
filename = "cutt-master.zip"
|
||||
fullname = os.path.join(root_folder, filename)
|
||||
dirname = os.path.join(root_folder, filename.replace(".zip",""))
|
||||
true_md5 = "a6f4f7f75310a69b131e21f1ebec768a"
|
||||
true_md5 = "af5bc35eea1832a42c0e0011659b7209"
|
||||
|
||||
if os.path.exists(fullname):
|
||||
md5 = run_cmd('md5sum '+fullname).split()[0]
|
||||
|
@ -205,7 +210,11 @@ def install_cutt(root_folder):
|
|||
zf.close()
|
||||
|
||||
LOG.i("installing cutt...")
|
||||
run_cmd(f"make", cwd=dirname)
|
||||
arch_flag = ""
|
||||
if len(flags.cuda_archs):
|
||||
arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
|
||||
arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
|
||||
run_cmd(f"make NVCC_GENCODE='{arch_flag}' nvcc_path='{nvcc_path}'", cwd=dirname)
|
||||
return dirname
|
||||
|
||||
def setup_cutt():
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
@ -74,10 +75,17 @@ def compile(compiler, flags, inputs, output, combind_build=False):
|
|||
for input, obj_file in zip(inputs, obj_files):
|
||||
cc = compiler
|
||||
nflags = oflags
|
||||
if has_cuda and input.endswith(".cu"):
|
||||
nflags = convert_nvcc_flags(oflags)
|
||||
cc = nvcc_path
|
||||
if input.endswith(".cu"):
|
||||
if has_cuda:
|
||||
nflags = convert_nvcc_flags(oflags)
|
||||
cc = nvcc_path
|
||||
else:
|
||||
continue
|
||||
cmd = f"{cc} {input} {nflags} -c {lto_flags} -o {obj_file}"
|
||||
if "nan_checker" in input:
|
||||
# nan checker needs to disable fast_math
|
||||
cmd = cmd.replace("--use_fast_math", "")
|
||||
cmd = cmd.replace("-Ofast", "-O2")
|
||||
cmds.append(cmd)
|
||||
jit_utils.run_cmds(cmds, cache_path, jittor_path, "Compiling "+base_output)
|
||||
cmd = f"{compiler} {' '.join(obj_files)} {flags} {lto_flags} {link} -o {output}"
|
||||
|
@ -894,6 +902,8 @@ make_cache_dir(cache_path)
|
|||
make_cache_dir(os.path.join(cache_path, "jit"))
|
||||
make_cache_dir(os.path.join(cache_path, "obj_files"))
|
||||
make_cache_dir(os.path.join(cache_path, "gen"))
|
||||
ck_path = os.path.join(cache_path, "checkpoints")
|
||||
make_cache_dir(ck_path)
|
||||
|
||||
# build cache_compile
|
||||
cc_flags += f" -I{jittor_path}/src "
|
||||
|
@ -943,7 +953,8 @@ pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path)
|
|||
# 3. op_utils
|
||||
# 4. other
|
||||
files2 = pyjt_gen_src
|
||||
files4 = run_cmd('find -L src | grep "cc$"', jittor_path).splitlines()
|
||||
grep_args = '"c[cu]$"' if has_cuda else '"cc$"'
|
||||
files4 = run_cmd('find -L src | grep '+grep_args, jittor_path).splitlines()
|
||||
at_beginning = [
|
||||
"src/ops/op_utils.cc",
|
||||
"src/event_queue.cc",
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Guoye Yang <498731903@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
@ -12,6 +13,7 @@ import numpy as np
|
|||
from jittor import pool
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
def argmax_pool(x, size, stride, padding=0):
|
||||
return pool.pool(x, size, 'maximum', padding, stride)
|
||||
|
||||
|
@ -196,28 +198,37 @@ def setitem(x, slices, value):
|
|||
mask = jt.broadcast(slices, x)
|
||||
value = jt.broadcast(value, x)
|
||||
return x.assign(mask.ternary(value, x))
|
||||
if isinstance(slices, list):
|
||||
slices = tuple(slices)
|
||||
if isinstance(slices, Sequence):
|
||||
ss = []
|
||||
for s in slices:
|
||||
if isinstance(s, jt.Var) and s.dtype == "bool":
|
||||
ss.extend(s.where())
|
||||
else:
|
||||
ss.append(s)
|
||||
slices = tuple(ss)
|
||||
return x.assign(x.setitem(slices, value))
|
||||
|
||||
jt.Var.__getitem__ = jt.Var.slice_var = getitem
|
||||
jt.Var.__setitem__ = setitem
|
||||
|
||||
def concat(arr, dim):
|
||||
def concat(arr, dim=0):
|
||||
'''Concat Operator can concat a list of jt Var at a specfic dimension.
|
||||
|
||||
* [in] x: input var list for concat
|
||||
|
||||
* [in] dim: concat which dim
|
||||
|
||||
* [out] out: concat result
|
||||
* return: concat result
|
||||
|
||||
Example::
|
||||
|
||||
jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
|
||||
# return [[1],[2],[2],[2]]
|
||||
'''
|
||||
# TODO: low performance when concat lots of vars
|
||||
if not isinstance(arr, Sequence):
|
||||
raise TypeError("concat arr needs to be a tuple or list")
|
||||
if len(arr) == 0:
|
||||
raise ValueError("need at least one array to concat")
|
||||
total_dim = 0
|
||||
if dim < 0: dim += len(arr[0].shape)
|
||||
for a in arr:
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Meng-Hao Guo <guomenghao1997@gmail.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
@ -27,8 +28,9 @@ mpi = jt.mpi
|
|||
img_open_hook = HookTimer(Image, "open")
|
||||
|
||||
class Worker:
|
||||
def __init__(self, target, args, buffer_size):
|
||||
def __init__(self, target, args, buffer_size, keep_numpy_array=False):
|
||||
self.buffer = jt.RingBuffer(buffer_size)
|
||||
self.buffer.keep_numpy_array(keep_numpy_array)
|
||||
|
||||
self.status = mp.Array('f', 5, lock=False)
|
||||
self.p = mp.Process(target=target, args=args+(self.buffer,self.status))
|
||||
|
@ -67,7 +69,8 @@ class Dataset(object):
|
|||
drop_last = False,
|
||||
num_workers = 0,
|
||||
buffer_size = 512*1024*1024,
|
||||
stop_grad = True):
|
||||
stop_grad = True,
|
||||
keep_numpy_array = False):
|
||||
super().__init__()
|
||||
self.total_len = None
|
||||
self.batch_size = batch_size
|
||||
|
@ -76,6 +79,7 @@ class Dataset(object):
|
|||
self.num_workers = num_workers
|
||||
self.buffer_size = buffer_size
|
||||
self.stop_grad = stop_grad
|
||||
self.keep_numpy_array = keep_numpy_array
|
||||
|
||||
def __getitem__(self, index):
|
||||
raise NotImplementedError
|
||||
|
@ -101,7 +105,7 @@ class Dataset(object):
|
|||
Attrs:
|
||||
|
||||
* batch_size(int): batch size, default 16.
|
||||
* totol_len(int): totol lenght.
|
||||
* total_len(int): total lenght.
|
||||
* shuffle(bool): shuffle at each epoch, default False.
|
||||
* drop_last(bool): if true, the last batch of dataset might smaller than batch_size, default True.
|
||||
* num_workers: number of workers for loading data
|
||||
|
@ -111,12 +115,15 @@ class Dataset(object):
|
|||
for k,v in kw.items():
|
||||
assert hasattr(self, k), k
|
||||
setattr(self, k, v)
|
||||
self.reset()
|
||||
return self
|
||||
|
||||
def to_jittor(self, batch):
|
||||
'''
|
||||
Change batch data to jittor array, such as np.ndarray, int, and float.
|
||||
'''
|
||||
if self.keep_numpy_array: return batch
|
||||
if isinstance(batch, jt.Var): return batch
|
||||
to_jt = lambda x: jt.array(x).stop_grad() \
|
||||
if self.stop_grad else jt.array(x)
|
||||
if isinstance(batch, np.ndarray):
|
||||
|
@ -129,7 +136,7 @@ class Dataset(object):
|
|||
isinstance(a, float):
|
||||
new_batch.append(to_jt(a))
|
||||
else:
|
||||
new_batch.append(a)
|
||||
new_batch.append(self.to_jittor(a))
|
||||
return new_batch
|
||||
|
||||
def collate_batch(self, batch):
|
||||
|
@ -299,11 +306,26 @@ Example::
|
|||
self.num_idle_c = mp.Condition(self.gid.get_lock())
|
||||
for i in range(self.num_workers):
|
||||
w = Worker(target=self._worker_main, args=(i,),
|
||||
buffer_size=self.buffer_size)
|
||||
buffer_size=self.buffer_size,
|
||||
keep_numpy_array=self.keep_numpy_array)
|
||||
workers.append(w)
|
||||
self.workers = workers
|
||||
self.index_list_numpy = np.ndarray(dtype='int32', shape=self.real_len, buffer=self.index_list)
|
||||
|
||||
def reset(self):
|
||||
if not hasattr(self, "workers"):
|
||||
return
|
||||
self._stop_all_workers()
|
||||
self.terminate()
|
||||
del self.index_list
|
||||
del self.idmap
|
||||
del self.gid
|
||||
del self.gidc
|
||||
del self.num_idle
|
||||
del self.num_idle_c
|
||||
del self.workers
|
||||
del self.index_list_numpy
|
||||
|
||||
def __del__(self):
|
||||
if mp_log_v:
|
||||
print("dataset deleted")
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Meng-Hao Guo <guomenghao1997@gmail.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guoye Yang <498731903@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
@ -308,7 +309,7 @@ class DepthwiseConv(Function):
|
|||
|
||||
grid = dim3(ksize_width, ksize_height, output_channels);
|
||||
threads = dim3(std::min(output_width, block_size), crop_output_height, 1);
|
||||
|
||||
cudaMemsetAsync(filter_grad_p, 0, filter_grad->size);
|
||||
|
||||
KernelDepthwiseConvFilterGrad<
|
||||
input_type><<<grid, threads, 0>>>(
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Authors:
|
||||
# Maintainers:
|
||||
# Haoyang Peng <2247838039@qq.com>
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# Wenyang Zhou <576825820@qq.com>
|
||||
#
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
@ -96,6 +97,27 @@ def repeat(x, *shape):
|
|||
|
||||
jt.Var.repeat = repeat
|
||||
|
||||
def repeat_interleave(x,repeats,dim=None):
|
||||
# TODO repeats is jt.Var
|
||||
assert isinstance(repeats,int)
|
||||
if dim == None:
|
||||
x = x.reshape(-1)
|
||||
dim=0
|
||||
if dim<0: dim+=x.ndim
|
||||
|
||||
tar_shape = list(x.shape)
|
||||
x_shape = list(x.shape)
|
||||
tar_shape[dim] = tar_shape[dim]*repeats
|
||||
dims = []
|
||||
for i in range(len(tar_shape)):
|
||||
if dim==i:
|
||||
dims.append(f"i{i}/{repeats}")
|
||||
else:
|
||||
dims.append(f"i{i}")
|
||||
return x.reindex(tar_shape,dims)
|
||||
|
||||
jt.Var.repeat_interleave = repeat_interleave
|
||||
|
||||
def chunk(x, chunks, dim=0):
|
||||
r'''
|
||||
Splits a var into a specific number of chunks. Each chunk is a view of the input var.
|
||||
|
@ -209,15 +231,18 @@ def flip(x, dim=0):
|
|||
>>> x.flip(1)
|
||||
[[4 3 2 1]]
|
||||
'''
|
||||
assert isinstance(dim, int)
|
||||
if dim<0:
|
||||
dim+=x.ndim
|
||||
assert dim>=0 and dim<len(x.shape)
|
||||
if isinstance(dim, int):
|
||||
dim = [dim]
|
||||
for i in range(len(dim)):
|
||||
if dim[i]<0:
|
||||
dim[i] += x.ndim
|
||||
assert dim[i]>=0 and dim[i]<x.ndim
|
||||
dim = set(dim)
|
||||
|
||||
tar_dims = []
|
||||
for i in range(len(x.shape)):
|
||||
if i == dim:
|
||||
tar_dims.append(f"{x.shape[dim]-1}-i{i}")
|
||||
if i in dim:
|
||||
tar_dims.append(f"xshape{i}-1-i{i}")
|
||||
else:
|
||||
tar_dims.append(f"i{i}")
|
||||
return x.reindex(x.shape, tar_dims)
|
||||
|
@ -335,16 +360,37 @@ def unbind(x, dim=0):
|
|||
jt.Var.unbind = unbind
|
||||
|
||||
def make_grid(x, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0):
|
||||
assert range == None
|
||||
assert isinstance(range, tuple) or range is None
|
||||
assert scale_each == False
|
||||
if isinstance(x, list): x = jt.stack(x)
|
||||
if normalize: x = (x - x.min()) / (x.max() - x.min())
|
||||
if normalize:
|
||||
if range is None: x = (x - x.min()) / (x.max() - x.min())
|
||||
else: x = (x - range[0]) / (range[1] - range[0])
|
||||
b,c,h,w = x.shape
|
||||
ncol = math.ceil(b / nrow)
|
||||
return x.reindex([c, h*ncol+(ncol+1)*padding, w*nrow+(nrow+1)*padding],
|
||||
[f"i1/{padding+h}*{nrow}+i2/{padding+w}", "i0",
|
||||
f"i1-i1/{padding+h}*{padding+h}-{padding}", f"i2-i2/{padding+w}*{padding+w}-{padding}"], overflow_value=pad_value)
|
||||
|
||||
def save_image(
|
||||
x,
|
||||
filepath,
|
||||
nrow: int = 8,
|
||||
padding: int = 2,
|
||||
normalize: bool = False,
|
||||
range = None,
|
||||
scale_each = False,
|
||||
pad_value = 0,
|
||||
format = None
|
||||
):
|
||||
from PIL import Image
|
||||
grid = make_grid(x, nrow=nrow, padding=padding, pad_value=pad_value,
|
||||
normalize=normalize, range=range, scale_each=scale_each)
|
||||
|
||||
ndarr = (grid*255+0.5).clamp(0, 255).permute(1, 2, 0).uint8().numpy()
|
||||
im = Image.fromarray(ndarr)
|
||||
im.save(filepath, format=format)
|
||||
|
||||
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
|
@ -582,12 +628,11 @@ def gather(x,dim,index):
|
|||
return x.reindex(ins)
|
||||
jt.Var.gather = gather
|
||||
|
||||
def prod(x,dim=0):
|
||||
def _prod(x,dim=0):
|
||||
x = jt.log(x)
|
||||
x = x.sum(dim=dim)
|
||||
return jt.exp(x)
|
||||
|
||||
jt.Var.prod = prod
|
||||
|
||||
def cumsum_forward(np, data):
|
||||
a = data['inputs'][0]
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Wenyang Zhou <576825820@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
@ -61,6 +62,7 @@ class AlexNet(nn.Module):
|
|||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
def alexnet(**kwargs):
|
||||
def alexnet(pretrained=False, **kwargs):
|
||||
model = AlexNet(**kwargs)
|
||||
if pretrained: model.load("jittorhub://alexnet.pkl")
|
||||
return model
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
@ -21,7 +22,7 @@ def densenet121(pretrained=False, **kwargs):
|
|||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
'''
|
||||
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
|
||||
assert not pretrained, "pretrained doesn't support now"
|
||||
if pretrained: model.load("jittorhub://densenet121.pkl")
|
||||
return model
|
||||
|
||||
def densenet161(pretrained=False, **kwargs):
|
||||
|
@ -32,7 +33,7 @@ def densenet161(pretrained=False, **kwargs):
|
|||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
'''
|
||||
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs)
|
||||
assert not pretrained, "pretrained doesn't support now"
|
||||
if pretrained: model.load("jittorhub://densenet161.pkl")
|
||||
return model
|
||||
|
||||
def densenet169(pretrained=False, **kwargs):
|
||||
|
@ -43,7 +44,7 @@ def densenet169(pretrained=False, **kwargs):
|
|||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
'''
|
||||
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
|
||||
assert not pretrained, "pretrained doesn't support now"
|
||||
if pretrained: model.load("jittorhub://densenet169.pkl")
|
||||
return model
|
||||
|
||||
def densenet201(pretrained=False, **kwargs):
|
||||
|
@ -54,7 +55,7 @@ def densenet201(pretrained=False, **kwargs):
|
|||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
'''
|
||||
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
|
||||
assert not pretrained, "pretrained doesn't support now"
|
||||
if pretrained: model.load("jittorhub://densenet201.pkl")
|
||||
return model
|
||||
|
||||
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Wenyang Zhou <576825820@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
@ -12,8 +13,10 @@ from jittor import nn
|
|||
|
||||
__all__ = ['GoogLeNet', 'googlenet']
|
||||
|
||||
def googlenet(**kwargs):
|
||||
return GoogLeNet(**kwargs)
|
||||
def googlenet(pretrained=False, **kwargs):
|
||||
model = GoogLeNet(**kwargs)
|
||||
if pretrained: model.load("jittorhub://googlenet.pkl")
|
||||
return model
|
||||
|
||||
class GoogLeNet(nn.Module):
|
||||
""" GoogLeNet model architecture.
|
||||
|
|
|
@ -4,7 +4,9 @@ from jittor import nn
|
|||
__all__ = ['Inception3', 'inception_v3']
|
||||
|
||||
def inception_v3(pretrained=False, progress=True, **kwargs):
|
||||
return Inception3(**kwargs)
|
||||
model = Inception3(**kwargs)
|
||||
if pretrained: model.load("jittorhub://inception_v3.pkl")
|
||||
return model
|
||||
|
||||
class Inception3(nn.Module):
|
||||
""" Inceptionv3 model architecture.
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Wenyang Zhou <576825820@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
@ -90,18 +91,22 @@ class MNASNet(nn.Module):
|
|||
x = x.mean([2, 3])
|
||||
return self.classifier(x)
|
||||
|
||||
def mnasnet0_5(**kwargs):
|
||||
def mnasnet0_5(pretrained=False, **kwargs):
|
||||
model = MNASNet(0.5, **kwargs)
|
||||
if pretrained: model.load("jittorhub://mnasnet0_5.pkl")
|
||||
return model
|
||||
|
||||
def mnasnet0_75(**kwargs):
|
||||
def mnasnet0_75(pretrained=False, **kwargs):
|
||||
model = MNASNet(0.75, **kwargs)
|
||||
if pretrained: model.load("jittorhub://mnasnet0_75.pkl")
|
||||
return model
|
||||
|
||||
def mnasnet1_0(**kwargs):
|
||||
def mnasnet1_0(pretrained=False, **kwargs):
|
||||
model = MNASNet(1.0, **kwargs)
|
||||
if pretrained: model.load("jittorhub://mnasnet1_0.pkl")
|
||||
return model
|
||||
|
||||
def mnasnet1_3(**kwargs):
|
||||
def mnasnet1_3(pretrained=False, **kwargs):
|
||||
model = MNASNet(1.3, **kwargs)
|
||||
if pretrained: model.load("jittorhub://mnasnet1_3.pkl")
|
||||
return model
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Wenyang Zhou <576825820@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
@ -93,7 +94,8 @@ class MobileNetV2(nn.Module):
|
|||
def execute(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
def mobilenet_v2():
|
||||
def mobilenet_v2(pretrained=False):
|
||||
model = MobileNetV2()
|
||||
if pretrained: model.load("jittorhub://mobilenet_v2.pkl")
|
||||
return model
|
||||
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Wenyang Zhou <576825820@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
@ -154,16 +155,23 @@ def _resnet(block, layers, **kwargs):
|
|||
model = ResNet(block, layers, **kwargs)
|
||||
return model
|
||||
|
||||
def Resnet18(**kwargs):
|
||||
return _resnet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
||||
def Resnet18(pretrained=False, **kwargs):
|
||||
model = _resnet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
||||
if pretrained: model.load("jittorhub://resnet18.pkl")
|
||||
return model
|
||||
resnet18 = Resnet18
|
||||
|
||||
def Resnet34(**kwargs):
|
||||
return _resnet( BasicBlock, [3, 4, 6, 3], **kwargs)
|
||||
def Resnet34(pretrained=False, **kwargs):
|
||||
model = _resnet(BasicBlock, [3, 4, 6, 3], **kwargs)
|
||||
if pretrained: model.load("jittorhub://resnet34.pkl")
|
||||
return model
|
||||
resnet34 = Resnet34
|
||||
|
||||
def Resnet50(**kwargs):
|
||||
return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
def Resnet50(pretrained=False, **kwargs):
|
||||
model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
if pretrained: model.load("jittorhub://resnet50.pkl")
|
||||
return model
|
||||
|
||||
resnet50 = Resnet50
|
||||
|
||||
def Resnet38(**kwargs):
|
||||
|
@ -174,7 +182,7 @@ def Resnet26(**kwargs):
|
|||
return _resnet(Bottleneck, [1, 2, 4, 1], **kwargs)
|
||||
resnet26 = Resnet26
|
||||
|
||||
def Resnet101(**kwargs):
|
||||
def Resnet101(pretrained=False, **kwargs):
|
||||
"""
|
||||
ResNet-101 model architecture.
|
||||
|
||||
|
@ -188,28 +196,38 @@ def Resnet101(**kwargs):
|
|||
return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
|
||||
resnet101 = Resnet101
|
||||
|
||||
def Resnet152(**kwargs):
|
||||
return _resnet(Bottleneck, [3, 8, 36, 3], **kwargs)
|
||||
def Resnet152(pretrained=False, **kwargs):
|
||||
model = _resnet(Bottleneck, [3, 8, 36, 3], **kwargs)
|
||||
if pretrained: model.load("jittorhub://resnet152.pkl")
|
||||
return model
|
||||
resnet152 = Resnet152
|
||||
|
||||
def Resnext50_32x4d(**kwargs):
|
||||
def Resnext50_32x4d(pretrained=False, **kwargs):
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 4
|
||||
return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
if pretrained: model.load("jittorhub://resnext50_32x4d.pkl")
|
||||
return model
|
||||
resnext50_32x4d = Resnext50_32x4d
|
||||
|
||||
def Resnext101_32x8d(**kwargs):
|
||||
def Resnext101_32x8d(pretrained=False, **kwargs):
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 8
|
||||
return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
|
||||
model = _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
|
||||
if pretrained: model.load("jittorhub://resnext101_32x8d.pkl")
|
||||
return model
|
||||
resnext101_32x8d = Resnext101_32x8d
|
||||
|
||||
def Wide_resnet50_2(**kwargs):
|
||||
def Wide_resnet50_2(pretrained=False, **kwargs):
|
||||
kwargs['width_per_group'] = (64 * 2)
|
||||
return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
if pretrained: model.load("jittorhub://wide_resnet50_2.pkl")
|
||||
return model
|
||||
wide_resnet50_2 = Wide_resnet50_2
|
||||
|
||||
def Wide_resnet101_2(**kwargs):
|
||||
def Wide_resnet101_2(pretrained=False, **kwargs):
|
||||
kwargs['width_per_group'] = (64 * 2)
|
||||
return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
|
||||
model = _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
|
||||
if pretrained: model.load("jittorhub://wide_resnet101_2.pkl")
|
||||
return model
|
||||
wide_resnet101_2 = Wide_resnet101_2
|
|
@ -1,9 +1,10 @@
|
|||
|
||||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Wenyang Zhou <576825820@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
@ -93,14 +94,22 @@ def _shufflenetv2(arch, *args):
|
|||
model = ShuffleNetV2(*args)
|
||||
return model
|
||||
|
||||
def shufflenet_v2_x0_5():
|
||||
return _shufflenetv2('shufflenetv2_x0.5', [4, 8, 4], [24, 48, 96, 192, 1024])
|
||||
def shufflenet_v2_x0_5(pretrained=False):
|
||||
model = _shufflenetv2('shufflenetv2_x0.5', [4, 8, 4], [24, 48, 96, 192, 1024])
|
||||
if pretrained: model.load("jittorhub://shufflenet_v2_x0_5.pkl")
|
||||
return model
|
||||
|
||||
def shufflenet_v2_x1_0():
|
||||
return _shufflenetv2('shufflenetv2_x1.0', [4, 8, 4], [24, 116, 232, 464, 1024])
|
||||
def shufflenet_v2_x1_0(pretrained=False):
|
||||
model = _shufflenetv2('shufflenetv2_x1.0', [4, 8, 4], [24, 116, 232, 464, 1024])
|
||||
if pretrained: model.load("jittorhub://shufflenet_v2_x1_0.pkl")
|
||||
return model
|
||||
|
||||
def shufflenet_v2_x1_5():
|
||||
return _shufflenetv2('shufflenetv2_x1.5', [4, 8, 4], [24, 176, 352, 704, 1024])
|
||||
def shufflenet_v2_x1_5(pretrained=False):
|
||||
model = _shufflenetv2('shufflenetv2_x1.5', [4, 8, 4], [24, 176, 352, 704, 1024])
|
||||
if pretrained: model.load("jittorhub://shufflenet_v2_x1_5.pkl")
|
||||
return model
|
||||
|
||||
def shufflenet_v2_x2_0():
|
||||
return _shufflenetv2('shufflenetv2_x2.0', [4, 8, 4], [24, 244, 488, 976, 2048])
|
||||
def shufflenet_v2_x2_0(pretrained=False):
|
||||
model = _shufflenetv2('shufflenetv2_x2.0', [4, 8, 4], [24, 244, 488, 976, 2048])
|
||||
if pretrained: model.load("jittorhub://shufflenet_v2_x2_0.pkl")
|
||||
return model
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Wenyang Zhou <576825820@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
@ -83,8 +84,12 @@ def _squeezenet(version, **kwargs):
|
|||
model = SqueezeNet(version, **kwargs)
|
||||
return model
|
||||
|
||||
def squeezenet1_0(**kwargs):
|
||||
return _squeezenet('1_0', **kwargs)
|
||||
def squeezenet1_0(pretrained=False, **kwargs):
|
||||
model = _squeezenet('1_0', **kwargs)
|
||||
if pretrained: model.load("jittorhub://squeezenet1_0.pkl")
|
||||
return model
|
||||
|
||||
def squeezenet1_1(**kwargs):
|
||||
return _squeezenet('1_1', **kwargs)
|
||||
def squeezenet1_1(pretrained=False, **kwargs):
|
||||
model = _squeezenet('1_1', **kwargs)
|
||||
if pretrained: model.load("jittorhub://squeezenet1_1.pkl")
|
||||
return model
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guoye Yang <498731903@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
@ -67,33 +68,49 @@ def _vgg(arch, cfg, batch_norm, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
def vgg11(**kwargs):
|
||||
return _vgg('vgg11', 'A', False, **kwargs)
|
||||
def vgg11(pretrained=False, **kwargs):
|
||||
model = _vgg('vgg11', 'A', False, **kwargs)
|
||||
if pretrained: model.load("jittorhub://vgg11.pkl")
|
||||
return model
|
||||
|
||||
|
||||
def vgg11_bn(**kwargs):
|
||||
return _vgg('vgg11_bn', 'A', True, **kwargs)
|
||||
def vgg11_bn(pretrained=False, **kwargs):
|
||||
model = _vgg('vgg11_bn', 'A', True, **kwargs)
|
||||
if pretrained: model.load("jittorhub://vgg11_bn.pkl")
|
||||
return model
|
||||
|
||||
|
||||
def vgg13(**kwargs):
|
||||
return _vgg('vgg13', 'B', False, **kwargs)
|
||||
def vgg13(pretrained=False, **kwargs):
|
||||
model = _vgg('vgg13', 'B', False, **kwargs)
|
||||
if pretrained: model.load("jittorhub://vgg13.pkl")
|
||||
return model
|
||||
|
||||
|
||||
def vgg13_bn(**kwargs):
|
||||
return _vgg('vgg13_bn', 'B', True, **kwargs)
|
||||
def vgg13_bn(pretrained=False, **kwargs):
|
||||
model = _vgg('vgg13_bn', 'B', True, **kwargs)
|
||||
if pretrained: model.load("jittorhub://vgg13_bn.pkl")
|
||||
return model
|
||||
|
||||
|
||||
def vgg16(**kwargs):
|
||||
return _vgg('vgg16', 'D', False, **kwargs)
|
||||
def vgg16(pretrained=False, **kwargs):
|
||||
model = _vgg('vgg16', 'D', False, **kwargs)
|
||||
if pretrained: model.load("jittorhub://vgg16.pkl")
|
||||
return model
|
||||
|
||||
|
||||
def vgg16_bn(**kwargs):
|
||||
return _vgg('vgg16_bn', 'D', True, **kwargs)
|
||||
def vgg16_bn(pretrained=False, **kwargs):
|
||||
model = _vgg('vgg16_bn', 'D', True, **kwargs)
|
||||
if pretrained: model.load("jittorhub://vgg16_bn.pkl")
|
||||
return model
|
||||
|
||||
|
||||
def vgg19(**kwargs):
|
||||
return _vgg('vgg19', 'E', False, **kwargs)
|
||||
def vgg19(pretrained=False, **kwargs):
|
||||
model = _vgg('vgg19', 'E', False, **kwargs)
|
||||
if pretrained: model.load("jittorhub://vgg19.pkl")
|
||||
return model
|
||||
|
||||
|
||||
def vgg19_bn(**kwargs):
|
||||
return _vgg('vgg19_bn', 'E', True, **kwargs)
|
||||
def vgg19_bn(pretrained=False, **kwargs):
|
||||
model = _vgg('vgg19_bn', 'E', True, **kwargs)
|
||||
if pretrained: model.load("jittorhub://vgg19_bn.pkl")
|
||||
return model
|
|
@ -1,12 +1,13 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Guoye Yang <498731903@qq.com>
|
||||
# Wenyang Zhou <576825820@qq.com>
|
||||
# Meng-Hao Guo <guomenghao1997@gmail.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
@ -16,7 +17,7 @@ import numpy as np
|
|||
import collections
|
||||
import math
|
||||
from collections import OrderedDict
|
||||
from jittor.pool import Pool, pool, AdaptiveAvgPool2d
|
||||
from jittor.pool import *
|
||||
from jittor.optim import *
|
||||
from jittor.misc import _pair
|
||||
|
||||
|
@ -154,6 +155,7 @@ def get_init_var_rand(shape, dtype):
|
|||
def relu(x): return jt.ternary((x>0.0), x, jt.broadcast_var(0.0, x))
|
||||
def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale)
|
||||
def relu6(x): return jt.minimum(jt.maximum(x, 0.0), 6.0)
|
||||
def elu(x,alpha=1.0):return jt.ternary(x>0,x,alpha*(x.exp()-1))
|
||||
def sign(x):
|
||||
one = jt.ones(x.shape)
|
||||
x = jt.ternary(x>0, one, x)
|
||||
|
@ -165,6 +167,13 @@ def gelu(x):
|
|||
r = erf*x*.5
|
||||
return r
|
||||
|
||||
class ELU(Module):
|
||||
def __init__(self,alpha=1.0):
|
||||
self.alpha=alpha
|
||||
|
||||
def execute(self,x):
|
||||
return elu(x,self.alpha)
|
||||
|
||||
class PReLU(Module):
|
||||
def __init__(self, num_parameters=1, init_=0.25):
|
||||
self.num_parameters = num_parameters
|
||||
|
@ -238,6 +247,30 @@ def smooth_l1_loss(y_true, y_pred,reduction="mean"):
|
|||
else:
|
||||
raise ValueError(f'not support {reduction}')
|
||||
|
||||
def nll_loss(output,target,weight=None,ignore_index=-100,reduction='mean'):
|
||||
assert output.ndim<=2 and output.ndim>0 and target.ndim==1
|
||||
n_classes = output.shape[-1]
|
||||
assert weight is None or weight.numel()==n_classes
|
||||
assert ignore_index<0 or ignore_index<n_classes
|
||||
if weight is None:
|
||||
weight = jt.ones((n_classes,))
|
||||
if ignore_index>0:
|
||||
weight[ignore_index]=0
|
||||
if output.ndim==2:
|
||||
index = jt.index((output.shape[0],),dim=0)
|
||||
loss = -output[index,target]*weight[target]
|
||||
else:
|
||||
loss = -output[target[0]]*weight[target[0]]
|
||||
if reduction=="mean":
|
||||
total_weight = weight[target].sum() if output.ndim==2 else weight[target[0]].sum()
|
||||
return loss.sum()/total_weight
|
||||
elif reduction=="sum":
|
||||
return loss.sum()
|
||||
elif reduction=="none":
|
||||
return loss
|
||||
else:
|
||||
raise ValueError(f'not support {reduction}')
|
||||
|
||||
class CrossEntropyLoss(Module):
|
||||
def __init__(self,ignore_index=None):
|
||||
self.ignore_index = ignore_index
|
||||
|
@ -330,6 +363,9 @@ class Dropout(Module):
|
|||
output = output * noise / (1.0 - self.p) # div keep prob
|
||||
return output
|
||||
|
||||
def dropout(x,p=0.5,is_train=False):
|
||||
return Dropout(p=p,is_train=is_train)(x)
|
||||
|
||||
class Linear(Module):
|
||||
def __init__(self, in_features, out_features, bias=True):
|
||||
self.in_features = in_features
|
||||
|
@ -707,6 +743,45 @@ class ConvTranspose(Module):
|
|||
y = y + b
|
||||
return y
|
||||
|
||||
def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
|
||||
x = input
|
||||
N,C,H,W = x.shape
|
||||
i,o,h,w = weight.shape
|
||||
assert C==i
|
||||
assert groups==1, "Group conv not supported yet."
|
||||
stride = stride if isinstance(stride, tuple) else (stride, stride)
|
||||
dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
|
||||
# added
|
||||
padding = padding if isinstance(padding, tuple) else (padding, padding)
|
||||
output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding)
|
||||
assert output_padding[0] < max(stride[0], dilation[0]) and \
|
||||
output_padding[1] < max(stride[1], dilation[1]), \
|
||||
"output padding must be smaller than max(stride, dilation)"
|
||||
|
||||
stride_h, stride_w = stride
|
||||
padding_h, padding_w = padding
|
||||
dilation_h, dilation_w = dilation
|
||||
|
||||
h_out = (H-1) * stride_h + output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
|
||||
w_out = (W-1) * stride_w + output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
|
||||
out_shape = (N, o, h_out, w_out)
|
||||
shape = (N, i, o, H, W, h, w)
|
||||
xx = x.broadcast(shape, (2, 5, 6)) # i,h,w
|
||||
ww = weight.broadcast(shape, (0, 3, 4)) # N,H,W
|
||||
y = (ww*xx).reindex_reduce("add", out_shape, [
|
||||
'i0', # N
|
||||
'i2', # o
|
||||
f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid
|
||||
f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid
|
||||
])
|
||||
if isinstance(bias, jt.Var):
|
||||
b = bias.broadcast(y.shape, [0,2,3])
|
||||
y = y + b
|
||||
else:
|
||||
assert not bias, "Bias should be none or jittor var"
|
||||
return y
|
||||
|
||||
conv_transpose2d = conv_transpose
|
||||
|
||||
def pad(x,padding, mode='constant', value=0):
|
||||
assert mode in ['constant','replicate','reflect','circular'],'only support constant,replicate,reflect,circular pad'
|
||||
|
@ -865,8 +940,9 @@ class Sigmoid(Module):
|
|||
def execute(self, x) :
|
||||
return x.sigmoid()
|
||||
|
||||
def softplus(x,beta=1,threshold=20):
|
||||
return 1 / beta * jt.log(1 + (beta * x).exp())
|
||||
def softplus(x,beta=1.0,threshold=20.0):
|
||||
return 1 / beta * jt.log(1 + (beta * x).minimum(threshold).exp()) + \
|
||||
(x - threshold/beta).maximum(0.0)
|
||||
|
||||
def hardtanh(x,min_val=-1,max_val=1):
|
||||
return jt.clamp(x,min_v=min_val,max_v=max_val)
|
||||
|
@ -887,7 +963,7 @@ class Softplus(Module):
|
|||
self.threshold = threshold
|
||||
|
||||
def execute(self, x):
|
||||
return 1 / self.beta * jt.log(1 + (self.beta * x).exp())
|
||||
return softplus(x, self.beta, self.threshold)
|
||||
|
||||
class Resize(Module):
|
||||
def __init__(self, size, mode="nearest", align_corners=False):
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Guoye Yang <498731903@qq.com>
|
||||
# Wenyang Zhou <576825820@qq.com>
|
||||
# Meng-Hao Guo <guomenghao1997@gmail.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
@ -33,11 +34,15 @@ class Optimizer(object):
|
|||
assert isinstance(pg, dict)
|
||||
self.param_groups.append(pg)
|
||||
self.n_step = 0
|
||||
|
||||
def add_param_group(self, group):
|
||||
self.param_groups.append(group)
|
||||
|
||||
@property
|
||||
def defaults(self):
|
||||
exclude = set(("defaults", "param_groups", "n_step"))
|
||||
return { k:v for k, v in self.__dict__.items() if k[0] != '_' and k not in exclude }
|
||||
exclude = set(("defaults", "param_groups", "n_step", "pre_step", "step"))
|
||||
return { k:v for k, v in self.__dict__.items()
|
||||
if k[0] != '_' and k not in exclude and not callable(v) }
|
||||
|
||||
def pre_step(self, loss):
|
||||
""" something should be done before step, such as calc gradients, mpi sync, and so on.
|
||||
|
@ -115,6 +120,12 @@ class SGD(Optimizer):
|
|||
for p in pg["params"]:
|
||||
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
|
||||
|
||||
def add_param_group(self, group):
|
||||
values = group["values"] = []
|
||||
for p in group["params"]:
|
||||
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
|
||||
self.param_groups.append(group)
|
||||
|
||||
def step(self, loss=None):
|
||||
if loss is not None:
|
||||
self.pre_step(loss)
|
||||
|
@ -159,6 +170,12 @@ class RMSprop(Optimizer):
|
|||
for p in pg["params"]:
|
||||
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
|
||||
|
||||
def add_param_group(self, group):
|
||||
values = group["values"] = []
|
||||
for p in group["params"]:
|
||||
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
|
||||
self.param_groups.append(group)
|
||||
|
||||
def step(self, loss=None):
|
||||
if loss is not None:
|
||||
self.pre_step(loss)
|
||||
|
@ -195,6 +212,14 @@ class Adam(Optimizer):
|
|||
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
|
||||
m.append(jt.zeros(p.shape, p.dtype).stop_grad())
|
||||
|
||||
def add_param_group(self, group):
|
||||
values = group["values"] = []
|
||||
m = group["m"] = []
|
||||
for p in group["params"]:
|
||||
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
|
||||
m.append(jt.zeros(p.shape, p.dtype).stop_grad())
|
||||
self.param_groups.append(group)
|
||||
|
||||
def step(self, loss=None):
|
||||
if loss is not None:
|
||||
self.pre_step(loss)
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Wenyang Zhou <576825820@qq.com>
|
||||
# Meng-Hao Guo <guomenghao1997@gmail.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
@ -30,11 +31,14 @@ class Pool(Module):
|
|||
if self.ceil_mode == False:
|
||||
h = (H+self.padding*2-self.kernel_size)//self.stride+1
|
||||
w = (W+self.padding*2-self.kernel_size)//self.stride+1
|
||||
use_code_op = self.op in ['maximum', 'minimum']
|
||||
# some second order avg_pool is require, so we don't use code op here
|
||||
else:
|
||||
h = (H+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
|
||||
w = (W+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
|
||||
use_code_op = self.op in ['maximum', 'minimum', 'mean']
|
||||
|
||||
if self.op in ['maximum', 'minimum', 'mean']:
|
||||
if use_code_op:
|
||||
if self.op == 'mean':
|
||||
if self.count_include_pad:
|
||||
count = f"int count = {self.kernel_size*self.kernel_size};"
|
||||
|
@ -187,5 +191,25 @@ class AdaptiveAvgPool2d(Module):
|
|||
])
|
||||
return xx.reduce("mean", [4,5])
|
||||
|
||||
def pool(x, kernel_size, op, padding=0, stride = 1):
|
||||
return Pool(kernel_size, stride, padding, op=op)(x)
|
||||
def pool(x, kernel_size, op, padding=0, stride=None):
|
||||
return Pool(kernel_size, stride, padding, op=op)(x)
|
||||
|
||||
class AvgPool2d(Module):
|
||||
def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
|
||||
self.layer = Pool(kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode, count_include_pad=count_include_pad, op="mean")
|
||||
|
||||
def execute(self, x):
|
||||
return self.layer(x)
|
||||
|
||||
def avg_pool2d(x, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
|
||||
return AvgPool2d(kernel_size, stride, padding, ceil_mode, count_include_pad)(x)
|
||||
|
||||
class MaxPool2d(Module):
|
||||
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False):
|
||||
self.layer = Pool(kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, return_indices=return_indices, ceil_mode=ceil_mode, op="maximum")
|
||||
|
||||
def execute(self, x):
|
||||
return self.layer(x)
|
||||
|
||||
def max_pool2d(x, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False):
|
||||
return MaxPool2d(kernel_size, stride, padding, dilation, return_indices, ceil_mode)(x)
|
|
@ -1,5 +1,6 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# Xiangli Li <190569238@qq.com>
|
||||
#
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
|
||||
class SparseVar:
|
||||
def __init__(self,indices,values,shape):
|
||||
assert isinstance(indices,jt.Var) and isinstance(values,jt.Var) and isinstance(shape,jt.NanoVector)
|
||||
self.indices = indices
|
||||
self.values = values
|
||||
self.shape = shape
|
||||
self.ndim = len(shape)
|
||||
|
||||
def _indices(self):
|
||||
return self.indices
|
||||
|
||||
def _values(self):
|
||||
return self.values
|
||||
|
||||
def t(self):
|
||||
indices = list(self.indices.split(1,dim=0))
|
||||
indices[-1],indices[-2] = indices[-2],indices[-1]
|
||||
indices = jt.contrib.concat(indices,dim=0)
|
||||
shape = list(self.shape)
|
||||
shape[-1],shape[-2] = shape[-2],shape[-1]
|
||||
shape = jt.NanoVector(shape)
|
||||
return SparseVar(indices,self.values,shape)
|
||||
|
||||
def to_dense(self):
|
||||
ret = jt.zeros(self.shape,self.values.dtype)
|
||||
indices = tuple(self.indices.split(1,dim=0))
|
||||
ret[indices]=self.values
|
||||
return ret
|
||||
|
||||
def sparse_array(indices,values,shape):
|
||||
return SparseVar(indices,values,shape)
|
||||
|
||||
def spmm(spase_x,y):
|
||||
assert isinstance(spase_x,SparseVar) and isinstance(y,jt.Var)
|
||||
assert spase_x.ndim==2 and y.ndim==2 and spase_x.shape[-1]==y.shape[0]
|
||||
|
||||
# TODO
|
||||
x = spase_x.to_dense()
|
||||
return jt.matmul(x,y)
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
|
@ -28,7 +28,7 @@ RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
|
|||
RUN pip3 uninstall jittor -y
|
||||
|
||||
COPY . jittor
|
||||
RUN python3.7 -m pip install jittor
|
||||
RUN python3.7 -m pip install ./jittor
|
||||
RUN python3.7 -m jittor.test.test_core
|
||||
EOF
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
|
|||
RUN pip3 uninstall jittor -y
|
||||
|
||||
COPY . jittor
|
||||
RUN python3.7 -m pip install jittor
|
||||
RUN python3.7 -m pip install ./jittor
|
||||
RUN python3.7 -m jittor.test.test_core
|
||||
EOF
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
|
|||
RUN pip3 uninstall jittor -y
|
||||
|
||||
COPY . jittor
|
||||
RUN python3.7 -m pip install jittor
|
||||
RUN python3.7 -m pip install ./jittor
|
||||
RUN python3.7 -m jittor.test.test_core
|
||||
EOF
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
|
|||
RUN pip3 uninstall jittor -y
|
||||
|
||||
COPY . jittor
|
||||
RUN python3.7 -m pip install jittor
|
||||
RUN python3.7 -m pip install ./jittor
|
||||
RUN python3.7 -m jittor.test.test_core
|
||||
EOF
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ RUN pip3 install jittor --timeout 100 && python3 -m jittor.test.test_example
|
|||
RUN pip3 uninstall jittor -y
|
||||
|
||||
COPY . jittor
|
||||
RUN python3 -m pip install jittor
|
||||
RUN python3 -m pip install ./jittor
|
||||
RUN python3 -m jittor.test.test_core
|
||||
EOF
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example
|
|||
RUN pip3 uninstall jittor -y
|
||||
|
||||
COPY . jittor
|
||||
RUN python3.7 -m pip install jittor
|
||||
RUN python3.7 -m pip install ./jittor
|
||||
RUN python3.7 -m jittor.test.test_core
|
||||
EOF
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guoye Yang <498731903@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
|
@ -5,7 +5,9 @@
|
|||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
from jittor.nn import Pool, pool
|
||||
from jittor.nn import Pool, pool, AvgPool2d, avg_pool2d
|
||||
from jittor.nn import MaxPool2d as j_MaxPool2d
|
||||
from jittor.nn import max_pool2d as j_max_pool2d
|
||||
import numpy as np
|
||||
from .test_core import expect_error
|
||||
from .test_grad import ngrad
|
||||
|
@ -101,7 +103,7 @@ class TestArgPoolOp(unittest.TestCase):
|
|||
check(jt_model, torch_model, shape, False)
|
||||
for i in range(10):
|
||||
check(jt_model, torch_model, [1,1,300,300], True)
|
||||
|
||||
|
||||
def test_cpu_(self):
|
||||
# x = jt.random([32, 128, 157, 300])
|
||||
x = jt.random([4, 128, 157, 300])
|
||||
|
@ -138,5 +140,50 @@ class TestArgPoolOp(unittest.TestCase):
|
|||
shape = (2, 16, 33, 33)
|
||||
check(jt_model, torch_model, shape, False)
|
||||
|
||||
def test_AvgPool2d(self):
|
||||
from torch.nn import AvgPool2d as t_AvgPool2d
|
||||
jt_model = AvgPool2d(3, 1, 1, ceil_mode=True)
|
||||
torch_model = t_AvgPool2d(3, 1, 1, ceil_mode=True)
|
||||
shape = (2, 16, 33, 33)
|
||||
check(jt_model, torch_model, shape, False)
|
||||
|
||||
jt_model = AvgPool2d(3, 1, 1, ceil_mode=True, count_include_pad=False)
|
||||
torch_model = t_AvgPool2d(3, 1, 1, ceil_mode=True, count_include_pad=False)
|
||||
shape = (2, 16, 100, 100)
|
||||
check(jt_model, torch_model, shape, False)
|
||||
print('finish')
|
||||
|
||||
def test_avg_pool2d(self):
|
||||
from torch.nn.functional import avg_pool2d as t_avg_pool2d
|
||||
arr = np.random.random((2, 16, 33, 33))
|
||||
jt_model = avg_pool2d(jt.array(arr), 3, 1, 1, ceil_mode=True)
|
||||
torch_model = t_avg_pool2d(torch.Tensor(arr), 3, 1, 1, ceil_mode=True)
|
||||
assert np.allclose(jt_model.numpy(), torch_model.numpy())
|
||||
|
||||
jt_model = avg_pool2d(jt.array(arr), 3, 1, 1, ceil_mode=True, count_include_pad=False)
|
||||
torch_model = t_avg_pool2d(torch.Tensor(arr), 3, 1, 1, ceil_mode=True, count_include_pad=False)
|
||||
assert np.allclose(jt_model.numpy(), torch_model.numpy())
|
||||
print('finish')
|
||||
|
||||
def test_MaxPool2d(self):
|
||||
from torch.nn import MaxPool2d
|
||||
jt_model = j_MaxPool2d(3, 1, 1, ceil_mode=True)
|
||||
torch_model = MaxPool2d(3, 1, 1, ceil_mode=True)
|
||||
shape = (2, 16, 33, 33)
|
||||
check(jt_model, torch_model, shape, False)
|
||||
print('finish')
|
||||
|
||||
def test_max_pool2d(self):
|
||||
from torch.nn.functional import max_pool2d
|
||||
arr = np.random.random((2, 16, 33, 33))
|
||||
jt_model = j_max_pool2d(jt.array(arr), 3, 1, 1, ceil_mode=True)
|
||||
torch_model = max_pool2d(torch.Tensor(arr), 3, 1, 1, ceil_mode=True)
|
||||
assert np.allclose(jt_model.numpy(), torch_model.numpy())
|
||||
|
||||
jt_model = j_max_pool2d(jt.array(arr), 3, 1, 1)
|
||||
torch_model = max_pool2d(torch.Tensor(arr), 3, 1, 1)
|
||||
assert np.allclose(jt_model.numpy(), torch_model.numpy())
|
||||
print('finish')
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -1,8 +1,9 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guoye Yang <498731903@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guoye Yang <498731903@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
|
||||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Authors: Dun Liang <randonlang@gmail.com>.
|
||||
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
|
||||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Wenyang Zhou <576825820@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Meng-Hao Guo <guomenghao1997@gmail.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue